Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # coding: utf-8 | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| from torchvision.transforms import functional as F | |
| from torchvision.transforms import InterpolationMode, Compose, Normalize | |
| from .video.transforms.na_resize import NaResize | |
| from .video.transforms.divisible_crop import DivisibleCrop | |
| from .video.transforms.rearrange import Rearrange | |
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize an image so its edges respect size limits and are multiples of a stride.

    The target size is derived in up to three passes: first a scale that keeps
    the longest edge at or below ``max_size`` (never upscaling for that reason)
    while bringing the shortest edge up to at least ``min_size``; then a shrink
    if the pixel count exceeds the per-image pixel budget; then a final shrink
    if the longest edge is still above ``max_size``. After every pass both
    edges are snapped to the nearest multiple of ``stride`` (and never below
    one stride), so the final size can deviate slightly from the limits.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the output height and width must be divisible.
        max_pixels (int): Pixel budget for the full image; divided by ``img_num``
            at call time.
        interpolation (InterpolationMode): Interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BICUBIC``.
            If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR``,
            and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``,
            are also accepted.
        antialias (bool, optional): Whether to apply antialiasing (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ):
        super().__init__()
        # Size constraints.
        self.max_size = max_size
        self.min_size = min_size
        self.max_pixels = max_pixels
        self.stride = stride
        # Options forwarded verbatim to ``F.resize``.
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Round ``value`` to the nearest multiple of ``stride``, with ``stride`` as the floor."""
        nearest_multiple = int(round(value / stride) * stride)
        return max(stride, nearest_multiple)

    def _apply_scale(self, width, height, scale):
        """Scale ``(width, height)`` by ``scale`` and snap both edges to the stride."""
        scaled_edges = (round(width * scale), round(height * scale))
        snapped_w, snapped_h = (
            self._make_divisible(edge, self.stride) for edge in scaled_edges
        )
        return snapped_w, snapped_h

    def forward(self, img, img_num=1):
        """Compute the constrained target size and resize ``img`` to it.

        Args:
            img (PIL Image or Tensor): Image to be resized. For a tensor the
                last two dimensions are read as (height, width).
            img_num (int): Number of images sharing the pixel budget; the
                per-image budget is ``max_pixels / img_num``.

        Returns:
            PIL Image or Tensor: Rescaled image with stride-divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            src_h, src_w = img.shape[-2:]
        else:
            src_w, src_h = img.size  # PIL reports (width, height)

        long_edge = max(src_w, src_h)
        short_edge = min(src_w, src_h)
        # Shrink (never enlarge) to honour max_size, then let min_size override
        # that if the short edge would otherwise end up too small.
        factor = max(min(self.max_size / long_edge, 1.0), self.min_size / short_edge)
        out_w, out_h = self._apply_scale(src_w, src_h, factor)

        # Ensure the number of pixels does not exceed the per-image budget.
        # NOTE(review): the factor used here is the ratio of *areas*, yet it is
        # applied to each edge, so the area shrinks by roughly its square.
        # Kept as-is to preserve the existing output sizes; confirm intent.
        pixel_budget = self.max_pixels / img_num
        area = out_w * out_h
        if area > pixel_budget:
            out_w, out_h = self._apply_scale(out_w, out_h, pixel_budget / area)

        # Ensure the longest edge does not exceed max_size.
        longest = max(out_w, out_h)
        if longest > self.max_size:
            out_w, out_h = self._apply_scale(out_w, out_h, self.max_size / longest)

        return F.resize(img, (out_h, out_w), self.interpolation, antialias=self.antialias)
class ImageTransform:
    """Resize an image, convert it to a tensor, and normalize it.

    Args:
        max_image_size: Upper bound for the longest edge, passed to the resize step.
        min_image_size: Lower bound for the shortest edge, passed to the resize step.
        image_stride: Divisor for output height and width; also exposed as ``self.stride``.
        max_pixels: Pixel budget passed to the resize step.
        image_mean: Per-channel mean used for normalization.
        image_std: Per-channel standard deviation used for normalization.
    """

    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14*14*9*1024,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5]
    ):
        self.stride = image_stride
        resize_options = dict(
            max_size=max_image_size,
            min_size=min_image_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(**resize_options)
        self.to_tensor_transform = transforms.ToTensor()
        # In-place normalization: it operates on the tensor produced by the
        # to-tensor step inside __call__.
        self.normalize_transform = transforms.Normalize(
            mean=image_mean, std=image_std, inplace=True
        )

    def __call__(self, img, img_num=1):
        """Apply resize -> to-tensor -> normalize and return the resulting tensor.

        Args:
            img: Image to process.
            img_num (int): Forwarded to the resize step, where it divides the
                pixel budget.
        """
        resized = self.resize_transform(img, img_num=img_num)
        as_tensor = self.to_tensor_transform(resized)
        return self.normalize_transform(as_tensor)
class VideoTransform:
    """Compose the video preprocessing pipeline: resize, crop, normalize, rearrange.

    The pipeline runs ``NaResize`` -> ``DivisibleCrop`` -> ``Normalize`` ->
    ``Rearrange("t c h w -> c t h w")`` in that order.

    Args:
        resolution: Forwarded to ``NaResize``.
        mode: Resize mode forwarded to ``NaResize``.
        divisible_crop_size: Forwarded to ``DivisibleCrop``.
        aspect_ratios: Forwarded to ``NaResize``.
        stride_spatial: Forwarded to ``NaResize`` as ``stride`` and stored on the instance.
        stride_temporal: Stored on the instance; not used by the pipeline itself.
        mean: Normalization mean.
        std: Normalization standard deviation.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        resolution=640,
        mode="area",
        divisible_crop_size=16,
        aspect_ratios=("21:9", "16:9", "4:3", "1:1", "3:4", "9:16"),
        stride_spatial=16,
        stride_temporal=4,
        mean=0.5,
        std=0.5,
        **kwargs
    ):
        resize_step = NaResize(
            resolution=resolution,
            mode=mode,
            downsample_only=True,
            stride=stride_spatial,
            # NOTE: aspect_ratios are only for `bucket` resize.
            aspect_ratios=aspect_ratios,
        )
        pipeline = [
            resize_step,
            DivisibleCrop(divisible_crop_size),
            Normalize(mean, std),
            Rearrange("t c h w -> c t h w"),
        ]
        self.transform = Compose(pipeline)

        # No single ``stride`` attribute is exposed; the spatial and temporal
        # strides are kept as separate attributes.
        self.stride_spatial = stride_spatial
        self.stride_temporal = stride_temporal

    def __call__(self, video):
        """Run ``video`` through the composed pipeline and return the result."""
        return self.transform(video)
class VisualTransform:
    """Resize, tensorize and normalize a single image or a sequence of frames.

    A single frame yields a ``[C, H, W]`` tensor. A sequence of frames (a
    list/tuple, or a 4-D numpy array / torch tensor whose first axis indexes
    frames) yields a ``[C, T, H, W]`` tensor.

    Frames may be PIL images, numpy arrays or torch tensors. Tensor frames
    cannot go through ``transforms.ToTensor`` (it only accepts PIL images and
    numpy arrays), and channels-first numpy frames cannot be measured by the
    resize step, so both are handled explicitly in ``_process_single``.

    Args:
        max_frame_size: Upper bound for the longest edge, passed to the resize step.
        min_frame_size: Lower bound for the shortest edge, passed to the resize step.
        image_stride: Divisor for output height and width; also exposed as ``self.stride``.
        max_pixels: Pixel budget passed to the resize step.
        image_mean: Per-channel mean used for normalization.
        image_std: Per-channel standard deviation used for normalization.
    """

    def __init__(
        self,
        max_frame_size,
        min_frame_size,
        image_stride,
        max_pixels=14*14*9*1024,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5]
    ):
        self.stride = image_stride
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_frame_size,
            min_size=min_frame_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)

    @staticmethod
    def _to_resizable(img):
        """Coerce an array/tensor frame into a form the resize step can measure.

        * 3-D numpy array whose last axis has 3 or 4 entries: treated as
          channels-last and wrapped in a PIL image.
        * Any other 3-D numpy array: treated as channels-first and converted to
          a torch tensor (the resize step reads ``.size`` as ``(width, height)``,
          which a raw ndarray does not provide).
        * 3-D torch tensor whose last axis has 3 or 4 entries and whose first
          axis is not a plausible channel count: treated as channels-last and
          permuted to channels-first, because the resize step reads the last
          two axes as (height, width).
        * Everything else (PIL images, channels-first tensors) is returned as is.
        """
        if isinstance(img, np.ndarray) and img.ndim == 3:
            if img.shape[-1] in (3, 4):
                return Image.fromarray(img)
            # ascontiguousarray guards against views torch.from_numpy rejects
            # (e.g. negative strides).
            return torch.from_numpy(np.ascontiguousarray(img))
        if (
            isinstance(img, torch.Tensor)
            and img.ndim == 3
            and img.shape[-1] in (3, 4)
            and img.shape[0] not in (1, 3, 4)
        ):
            return img.permute(2, 0, 1)
        return img

    def _process_single(self, img, img_num=1):
        """Resize, tensorize and normalize one frame; returns a ``[C, H, W]`` tensor.

        Args:
            img: A PIL image, a 3-D numpy array, or a 3-D torch tensor.
            img_num (int): Forwarded to the resize step, where it divides the
                pixel budget.
        """
        img = self._to_resizable(img)
        img = self.resize_transform(img, img_num=img_num)
        if isinstance(img, torch.Tensor):
            # ToTensor raises TypeError for tensors, so convert the dtype
            # directly; integer inputs are rescaled to floats in [0, 1].
            img = F.convert_image_dtype(img, torch.float32)
            # Normalize out of place: for float input that needed no resizing
            # the tensor may still be the caller's own data.
            return F.normalize(img, self.normalize_transform.mean, self.normalize_transform.std)
        img = self.to_tensor_transform(img)
        img = self.normalize_transform(img)
        return img

    def __call__(self, img, img_num=1):
        """Transform a single frame or a sequence of frames.

        Args:
            img: A single frame, a list/tuple of frames, or a 4-D numpy array /
                torch tensor whose first axis indexes frames (each frame either
                channels-first or channels-last).
            img_num (int): Forwarded to the resize step for every frame.

        Returns:
            torch.Tensor: ``[C, H, W]`` for a single frame, ``[C, T, H, W]`` for
            a sequence.
        """
        # --- Video sequence handling ---
        is_frame_batch = isinstance(img, (np.ndarray, torch.Tensor)) and img.ndim == 4
        if isinstance(img, (list, tuple)) or is_frame_batch:
            # Iterating a 4-D array/tensor walks its first (frame) axis.
            processed_frames = [self._process_single(frame, img_num=img_num) for frame in img]
            out = torch.stack(processed_frames)  # [T, C, H, W]
            return out.permute(1, 0, 2, 3)  # [C, T, H, W]

        # Single frame
        return self._process_single(img, img_num=img_num)