# Lance / data/transforms.py
# Uploaded by Nayefleb via huggingface_hub ("Upload folder using huggingface_hub"), commit 8b306b3 (verified).
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode, Compose, Normalize
from .video.transforms.na_resize import NaResize
from .video.transforms.divisible_crop import DivisibleCrop
from .video.transforms.rearrange import Rearrange
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize the input image so that its longest side and shortest side are within a specified range,
    ensuring that both sides are divisible by a specified stride.

    The target size is computed in three passes:

    1. Scale so the longest edge is at most ``max_size`` (this step alone never
       upscales), then upscale if needed so the shortest edge is at least
       ``min_size``. Both edges are then rounded to the nearest multiple of
       ``stride``.
    2. If the result holds more than ``max_pixels / img_num`` pixels, shrink both
       edges by the square root of the area ratio. Each edge is then floored to a
       multiple of ``stride``, so the pixel count ends at or just below the budget.
    3. If the longest edge still exceeds ``max_size`` (e.g. after the ``min_size``
       upscale), scale it back down and round to the stride again. ``max_size``
       therefore wins over ``min_size`` when the two conflict.

    Every edge is at least ``stride``, even if that means exceeding a bound.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the height and width of the image must be divisible.
        max_pixels (int): Maximum pixels for the full image; the per-call budget is
            ``max_pixels / img_num``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BICUBIC``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
        antialias (bool, optional): Whether to apply antialiasing (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ):
        super().__init__()
        self.max_size = max_size
        self.min_size = min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride, floor=False):
        """Snap ``value`` to a multiple of ``stride``, never returning less than ``stride``.

        By default rounds to the nearest multiple. With ``floor=True`` rounds down
        instead. A tiny epsilon absorbs float noise, so e.g. 703.9999999 with
        stride 16 still snaps to 704 rather than 688.
        """
        if floor:
            multiple = int(value / stride + 1e-6)
        else:
            multiple = round(value / stride)
        return max(stride, int(multiple * stride))

    def _apply_scale(self, width, height, scale):
        """Scale both edges by ``scale`` and round each to the nearest multiple of ``self.stride``."""
        new_width = round(width * scale)
        new_height = round(height * scale)
        new_width = self._make_divisible(new_width, self.stride)
        new_height = self._make_divisible(new_height, self.stride)
        return new_width, new_height

    def forward(self, img, img_num=1):
        """
        Args:
            img (PIL Image or Tensor): Image to be resized. For tensors, the last two
                dimensions are read as (height, width).
            img_num (int): Number of images sharing the ``max_pixels`` budget; the
                per-image budget is ``max_pixels / img_num``.

        Returns:
            PIL Image or Tensor: Rescaled image with divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        # Pass 1: cap the long edge (downscale only), then lift the short edge to min_size.
        scale = min(self.max_size / max(width, height), 1.0)
        scale = max(scale, self.min_size / min(width, height))
        new_width, new_height = self._apply_scale(width, height, scale)

        # Pass 2: keep the pixel count within the budget. Area scales with the square of
        # the per-edge factor, so use the square root of the area ratio (the raw ratio
        # would over-shrink quadratically). Floor-snap so rounding cannot push the
        # result back over the budget.
        budget = self.max_pixels / img_num
        if new_width * new_height > budget:
            edge_scale = (budget / (new_width * new_height)) ** 0.5
            new_width = self._make_divisible(new_width * edge_scale, self.stride, floor=True)
            new_height = self._make_divisible(new_height * edge_scale, self.stride, floor=True)

        # Pass 3: the min_size upscale in pass 1 may have pushed the long edge past max_size.
        # This pass only shrinks, so it cannot break the pixel budget from pass 2.
        if max(new_width, new_height) > self.max_size:
            scale = self.max_size / max(new_width, new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
class ImageTransform:
    """Resize, convert to tensor and normalize a single image.

    Pipeline: ``MaxLongEdgeMinShortEdgeResize`` -> ``transforms.ToTensor()`` ->
    ``transforms.Normalize(image_mean, image_std, inplace=True)``.

    Args:
        max_image_size: Upper bound for the longest edge after resizing.
        min_image_size: Lower bound for the shortest edge after resizing.
        image_stride: Both output edges are multiples of this value (kept as ``self.stride``).
        max_pixels: Pixel budget for the resized image, divided by ``img_num`` at call time.
        image_mean: Per-channel normalization mean.
        image_std: Per-channel normalization std.
    """

    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14*14*9*1024,
        # Immutable tuple defaults: these sequences are handed to Normalize, so a
        # mutable default list would be one object shared by every instance.
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5)
    ):
        self.stride = image_stride
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_image_size,
            min_size=min_image_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        # inplace=True normalizes the tensor produced by ToTensor in the step before,
        # presumably a fresh allocation for PIL input, so caller data is not touched.
        self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)

    def __call__(self, img, img_num=1):
        """Transform one image.

        Args:
            img: Image to transform. Presumably a PIL image: the resize step reads
                ``img.size`` as (width, height) for non-tensors, and ToTensor is
                applied afterwards.
            img_num: Number of images sharing the ``max_pixels`` budget.

        Returns:
            The normalized image tensor produced by the pipeline.
        """
        img = self.resize_transform(img, img_num=img_num)
        img = self.to_tensor_transform(img)
        img = self.normalize_transform(img)
        return img
class VideoTransform:
    """Native-resolution video preprocessing pipeline.

    A single ``Compose`` applies, in order:
    ``NaResize`` (downsample only, stride ``stride_spatial``) ->
    ``DivisibleCrop(divisible_crop_size)`` -> ``Normalize(mean, std)`` ->
    ``Rearrange("t c h w -> c t h w")``. The rearrange pattern implies that frames
    reach that step laid out as (T, C, H, W).

    Args:
        resolution: Target resolution forwarded to ``NaResize``.
        mode: Resize mode forwarded to ``NaResize``.
        divisible_crop_size: Value forwarded to ``DivisibleCrop``.
        aspect_ratios: Candidate aspect ratios; only relevant to the ``bucket`` resize mode.
        stride_spatial: Spatial stride forwarded to ``NaResize`` and stored on the instance.
        stride_temporal: Stored on the instance; not used by the pipeline itself.
        mean: Normalization mean forwarded to ``Normalize``.
        std: Normalization std forwarded to ``Normalize``.
        **kwargs: Accepted and ignored (presumably so a shared config can be splatted in).
    """

    def __init__(
        self,
        resolution=640,
        mode="area",
        divisible_crop_size=16,
        aspect_ratios=("21:9", "16:9", "4:3", "1:1", "3:4", "9:16"),
        stride_spatial=16,
        stride_temporal=4,
        mean=0.5,
        std=0.5,
        **kwargs
    ):
        resize_step = NaResize(
            resolution=resolution,
            mode=mode,
            downsample_only=True,
            stride=stride_spatial,
            aspect_ratios=aspect_ratios,  # only consulted by the `bucket` resize mode
        )
        crop_step = DivisibleCrop(divisible_crop_size)
        norm_step = Normalize(mean, std)
        layout_step = Rearrange("t c h w -> c t h w")
        self.transform = Compose([resize_step, crop_step, norm_step, layout_step])

        self.stride_spatial = stride_spatial
        self.stride_temporal = stride_temporal

    def __call__(self, video):
        """Run ``video`` through the composed pipeline and return the result."""
        pipeline = self.transform
        return pipeline(video)
class VisualTransform:
    """Resize, convert and normalize a single image or a sequence of video frames.

    Every frame goes through ``MaxLongEdgeMinShortEdgeResize``, is converted to a
    float tensor, and is normalized with ``image_mean`` / ``image_std``.

    Accepted inputs for ``__call__``:
      * a list/tuple of frames, a 4-D ``np.ndarray`` or a 4-D ``torch.Tensor``
        (frames along dim 0) -> tensor of shape ``(C, T, H, W)``;
      * anything else is treated as a single frame -> tensor of shape ``(C, H, W)``.

    A frame may be a PIL image, an ``np.ndarray`` or a ``torch.Tensor``:
      * Arrays whose last dimension is 3 or 4 are treated as channel-last and
        converted with ``Image.fromarray`` (presumably uint8 data).
      * Other arrays are treated as channel-first ``(C, H, W)``.
      * 3-D tensors whose last dimension is 3 or 4 are treated as ``(H, W, C)``
        and permuted to ``(C, H, W)``.
      * Integer tensors are scaled to ``[0, 1]``, as ``ToTensor`` does for PIL
        input; float tensors are assumed to already be in ``[0, 1]``.

    Args:
        max_frame_size: Upper bound for the longest edge of each frame.
        min_frame_size: Lower bound for the shortest edge of each frame.
        image_stride: Output edges are multiples of this value (kept as ``self.stride``).
        max_pixels: Pixel budget per frame, divided by ``img_num`` at call time.
        image_mean: Per-channel normalization mean.
        image_std: Per-channel normalization std.
    """

    def __init__(
        self,
        max_frame_size,
        min_frame_size,
        image_stride,
        max_pixels=14*14*9*1024,
        # Immutable tuple defaults: these sequences are handed to Normalize, so a
        # mutable default list would be one object shared by every instance.
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5)
    ):
        self.stride = image_stride
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_frame_size,
            min_size=min_frame_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)

    @staticmethod
    def _to_resizable(frame):
        """Return ``frame`` as a PIL image or a channel-first tensor, the inputs the resize step handles."""
        if isinstance(frame, np.ndarray):
            if frame.shape[-1] in (3, 4):
                # Channel-last array: go through PIL, like the rest of the PIL path.
                return Image.fromarray(frame)
            # Channel-first array: the resize step cannot read ndarray sizes, so wrap it as a tensor.
            return torch.from_numpy(np.ascontiguousarray(frame))
        if isinstance(frame, torch.Tensor) and frame.ndim == 3 and frame.shape[-1] in (3, 4):
            # Channel-last tensor (H, W, C) -> (C, H, W), the layout the resize step reads.
            return frame.permute(2, 0, 1)
        return frame

    def _process_single(self, img, img_num=1):
        """Resize, convert to a float tensor and normalize one frame."""
        img = self.resize_transform(self._to_resizable(img), img_num=img_num)
        if isinstance(img, torch.Tensor):
            # ToTensor only accepts PIL images / ndarrays. convert_image_dtype applies the
            # same [0, 1] scaling to integer tensors and leaves float32 tensors as-is.
            tensor = F.convert_image_dtype(img, torch.float32)
            if tensor is img:
                # No copy was made. F.resize can return its input unchanged, so copy before the
                # in-place Normalize to avoid writing into the caller's tensor/array.
                tensor = tensor.clone()
        else:
            tensor = self.to_tensor_transform(img)
        return self.normalize_transform(tensor)

    def _process_sequence(self, frames, img_num):
        """Process each frame and stack the results into a (C, T, H, W) tensor."""
        stacked = torch.stack([self._process_single(frame, img_num=img_num) for frame in frames])  # (T, C, H, W)
        return stacked.permute(1, 0, 2, 3)

    def __call__(self, img, img_num=1):
        """Transform a single frame or a sequence of frames.

        Args:
            img: A single frame, a list/tuple of frames, or a 4-D ndarray/tensor with
                frames along dim 0.
            img_num: Number of images sharing the ``max_pixels`` budget; applied per frame.

        Returns:
            ``(C, T, H, W)`` tensor for sequences, ``(C, H, W)`` tensor for a single frame.
        """
        # Video sequence given as a Python list/tuple of frames.
        if isinstance(img, (list, tuple)):
            return self._process_sequence(img, img_num)
        # Video sequence given as a 4-D array/tensor; iterating splits it along dim 0.
        if isinstance(img, (np.ndarray, torch.Tensor)) and img.ndim == 4:
            return self._process_sequence(list(img), img_num)
        # Single frame.
        return self._process_single(img, img_num=img_num)