| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Tokenization classes and Image processor class, Processor class for Ernie_45T_VL.""" |
|
|
| import copy |
| import io |
| import os |
| import math |
| import random |
| import requests |
| import base64 |
| import datetime |
| import hashlib |
| import threading |
| import uuid |
| import decord |
| from shutil import copyfile |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw, ImageFont |
| from PIL.ExifTags import TAGS |
| from collections import defaultdict |
| from pathlib import Path |
| from tempfile import NamedTemporaryFile as ntf |
|
|
| import sentencepiece as spm |
| from transformers.tokenization_utils import PreTrainedTokenizer |
| from transformers.tokenization_utils_base import ( |
| PaddingStrategy, |
| TextInput, |
| ) |
| from transformers.utils import TensorType, logging |
| from transformers.video_utils import VideoInput |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| from transformers.image_transforms import ( |
| convert_to_rgb, |
| normalize, |
| rescale, |
| resize, |
| to_channel_dimension_format, |
| ) |
| from transformers.image_utils import ( |
| OPENAI_CLIP_MEAN, |
| OPENAI_CLIP_STD, |
| ChannelDimension, |
| ImageInput, |
| PILImageResampling, |
| get_image_size, |
| infer_channel_dimension_format, |
| is_valid_image, |
| make_list_of_images, |
| to_numpy_array, |
| valid_images, |
| ) |
|
|
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
|
|
|
|
class Ernie4_5_VLTokenizer(PreTrainedTokenizer):
    """
    SentencePiece-backed tokenizer for Ernie 4.5 VL.

    Tokenization, id lookup and decoding are delegated to a
    `sentencepiece.SentencePieceProcessor` loaded from `vocab_file`. Padding is
    customised so that the attention mask (1-D, or a stacked 2-D causal mask)
    is padded together with the other inputs.
    """

    # File name under which the SentencePiece model is stored.
    vocab_files_names = {
        "vocab_file": "tokenizer.model",
    }
    # Keys this tokenizer may produce; the first entry is the one measured when padding.
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    # Side on which padding is applied by `_pad`.
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<cls>",
        eos_token="</s>",
        mask_token="<mask:0>",
        pad_token="<pad>",
        sep_token="<sep>",
        unk_token="<unk>",
        additional_special_tokens=None,
        **kwargs,
    ):
        """
        Load the SentencePiece model and register the special tokens.

        Args:
            vocab_file (str): Path to the SentencePiece model file.
            bos_token (str, optional): Beginning-of-sequence token. Defaults to `"<s>"`.
            cls_token (str, optional): Classifier token. Defaults to `"<cls>"`.
            eos_token (str, optional): End-of-sequence token. Defaults to `"</s>"`.
            mask_token (str, optional): Masking token. Defaults to `"<mask:0>"`.
            pad_token (str, optional): Padding token. Defaults to `"<pad>"`.
            sep_token (str, optional): Separator token. Defaults to `"<sep>"`.
            unk_token (str, optional): Unknown-token symbol. Defaults to `"<unk>"`.
            additional_special_tokens (List[str], optional): Extra special tokens.
                Defaults to `["<mask:1>", "<mask:7>"]` when `None`.
            **kwargs (dict): Forwarded unchanged to the superclass constructor.
        """
        # The SentencePiece model is loaded before the superclass constructor runs.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

        if additional_special_tokens is None:
            additional_special_tokens = ["<mask:1>", "<mask:7>"]

        special_tokens = {
            "bos_token": bos_token,
            "cls_token": cls_token,
            "eos_token": eos_token,
            "mask_token": mask_token,
            "pad_token": pad_token,
            "sep_token": sep_token,
            "unk_token": unk_token,
            "additional_special_tokens": additional_special_tokens,
        }
        super().__init__(**special_tokens, **kwargs)

    @property
    def space_token(self):
        """Token string used as the space token (`"<mask:1>"`)."""
        return "<mask:1>"

    @property
    def space_token_id(self):
        """SentencePiece id of the `"<mask:1>"` piece."""
        return self.sp_model.piece_to_id("<mask:1>")

    @property
    def gend_token(self):
        """Token string exposed as the `gend` token (`"<mask:7>"`)."""
        return "<mask:7>"

    @property
    def gend_token_id(self):
        """SentencePiece id of the `"<mask:7>"` piece."""
        return self.sp_model.piece_to_id("<mask:7>")

    @property
    def im_start_id(self):
        """SentencePiece id of the `"<|im_start|>"` piece."""
        return self.sp_model.piece_to_id("<|im_start|>")

    @property
    def im_end_id(self):
        """SentencePiece id of the `"<|im_end|>"` piece."""
        return self.sp_model.piece_to_id("<|im_end|>")

    @property
    def vocab_size(self):
        """Number of pieces in the SentencePiece model."""
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """Build a token -> id mapping covering the base vocabulary and the added tokens."""
        vocab = {}
        for index in range(self.vocab_size):
            vocab[self.convert_ids_to_tokens(index)] = index
        # Added tokens are merged last, so they win on key collisions.
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split `text` into SentencePiece pieces."""
        return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Look up the id of a single piece in the SentencePiece model."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Look up the piece for a single id in the SentencePiece model."""
        return self.sp_model.id_to_piece(id)

    def convert_tokens_to_string(self, tokens):
        """
        Join a token sequence back into text.

        Runs of ordinary tokens are decoded by SentencePiece; special tokens are
        copied through verbatim and act as boundaries between decoded runs.
        """
        special_tokens = self.all_special_tokens
        fragments = []
        pending = []
        for token in tokens:
            if token in special_tokens:
                # Flush what has been collected so far, then emit the special token as-is.
                fragments.append(self.sp_model.decode(pending))
                fragments.append(token)
                pending = []
            else:
                pending.append(token)
        # Decode whatever follows the last special token.
        fragments.append(self.sp_model.decode(pending))
        return "".join(fragments)

    def prepare_for_model(self, *args, **kwargs):
        """Delegate to the superclass after discarding any `add_special_tokens` keyword."""
        kwargs.pop("add_special_tokens", None)
        return super().prepare_for_model(*args, **kwargs)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Write the SentencePiece model file into `save_directory`.

        Args:
            save_directory (`str`): Existing directory to write into.
            filename_prefix (`str`, optional): When given, prepended to the file name
                followed by `"-"`.

        Returns:
            `Tuple(str)`: One-element tuple holding the output path. If
            `save_directory` is not a directory, an error is logged and `None` is
            returned instead.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        prefix = filename_prefix + "-" if filename_prefix else ""
        out_vocab_file = os.path.join(
            save_directory, prefix + self.vocab_files_names["vocab_file"]
        )

        if not os.path.isfile(self.vocab_file):
            # The source file is gone: serialise the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                fi.write(self.sp_model.serialized_model_proto())
        elif os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            # Source exists and differs from the destination: plain copy.
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    def _decode(self, *args, **kwargs):
        """Decode ids to text with tokenization clean-up and special-token spacing forced off."""
        # Caller-supplied values for these two options are dropped and replaced by False.
        for forced_option in (
            "clean_up_tokenization_spaces",
            "spaces_between_special_tokens",
        ):
            kwargs.pop(forced_option, None)

        return super()._decode(
            *args,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False,
            **kwargs,
        )

    def _pad(
        self,
        encoded_inputs: Dict,
        max_length: Optional[int] = None,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        **kwargs
    ) -> dict:
        """
        Pad one encoded example, handling the attention mask locally.

        When an attention mask is requested it is removed from `encoded_inputs`
        (or created as a lower-triangular matrix with a leading axis of size 1),
        zero-padded with numpy on `self.padding_side`, and re-attached as a nested
        list after the superclass has padded the remaining fields.

        Args:
            encoded_inputs (Dict): Encoded example; the entry named by
                `self.model_input_names[0]` determines the sequence length.
            max_length (int, optional): Target length.
            padding_strategy: A `PaddingStrategy` value.
            pad_to_multiple_of (int, optional): Round `max_length` up to a multiple
                of this value.
            return_attention_mask (bool, optional): Defaults to whether
                `"attention_mask"` is listed in `self.model_input_names`.
            **kwargs: Accepted and ignored.

        Returns:
            dict: The padded inputs.

        Raises:
            ValueError: For an attention mask that is not a tensor, list or array,
                or for a `padding_side` other than `"right"`/`"left"`.
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_attention_mask:
            seq_len = len(encoded_inputs[self.model_input_names[0]])
            if padding_strategy == PaddingStrategy.LONGEST:
                max_length = seq_len

            # Round the target length up to the next multiple when requested.
            if (
                max_length is not None
                and pad_to_multiple_of is not None
                and max_length % pad_to_multiple_of != 0
            ):
                max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

            has_mask = (
                "attention_mask" in encoded_inputs
                and encoded_inputs["attention_mask"] is not None
            )
            if not has_mask:
                # No mask supplied: lower-triangular mask with a leading axis of size 1.
                causal = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
                attention_mask = np.expand_dims(causal, axis=0)
            else:
                attention_mask = encoded_inputs.pop("attention_mask")
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = attention_mask.numpy()
                elif isinstance(attention_mask, list):
                    attention_mask = np.array(attention_mask)
                elif not isinstance(attention_mask, np.ndarray):
                    raise ValueError(
                        f"Unexpected type {type(attention_mask)} of attention_mask, "
                    )

            if padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_len != max_length:
                difference = max_length - seq_len
                if self.padding_side == "right":
                    edge = (0, difference)
                elif self.padding_side == "left":
                    edge = (difference, 0)
                else:
                    raise ValueError(
                        "Invalid padding strategy:" + str(self.padding_side)
                    )
                # 1-D masks grow along their only axis; otherwise the two trailing
                # axes grow and the leading axis is left alone.
                if attention_mask.ndim == 1:
                    pad_width = [edge]
                else:
                    pad_width = [(0, 0), edge, edge]

                attention_mask = np.pad(
                    attention_mask,
                    pad_width=pad_width,
                    mode="constant",
                    constant_values=0,
                )

        # The superclass pads everything else; the mask is handled here instead.
        encoded_inputs = super()._pad(
            encoded_inputs,
            max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=False,
        )

        if return_attention_mask:
            encoded_inputs["attention_mask"] = attention_mask.tolist()

        return encoded_inputs
|
|
|
|
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' nearest to 'number'.

    The quotient is rounded with the built-in `round`, so an exact .5 tie goes
    to the even multiple.
    """
    multiples = round(number / factor)
    return multiples * factor
|
|
|
|
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is greater than or equal to 'number'."""
    multiples = math.ceil(number / factor)
    return multiples * factor
|
|
|
|
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is less than or equal to 'number'."""
    multiples = math.floor(number / factor)
    return multiples * factor
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
):
    """
    Compute a target (height, width) such that:

    1. Both dimensions are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    Args:
        height: Original height in pixels.
        width: Original width in pixels.
        factor: Value both output dimensions must be a multiple of.
        min_pixels: Lower bound on `h_bar * w_bar`.
        max_pixels: Upper bound on `h_bar * w_bar`.

    Returns:
        Tuple `(h_bar, w_bar)` with the resized height and width.

    Raises:
        ValueError: If the computed size still falls outside
            ['min_pixels', 'max_pixels'].
        ZeroDivisionError: If `height` or `width` is 0 (the aspect ratio divides
            by the smaller dimension).
    """
    MAX_RATIO = 200
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        # Too elongated: snap the short side to a multiple of `factor` and cap
        # the long side at MAX_RATIO times that.
        if height > width:
            new_width = max(factor, round_by_factor(width, factor))
            new_height = floor_by_factor(new_width * MAX_RATIO, factor)
        else:
            new_height = max(factor, round_by_factor(height, factor))
            new_width = floor_by_factor(new_height * MAX_RATIO, factor)

        # Adjacent f-strings keep the message independent of source indentation
        # (a backslash continuation inside the literal leaks whitespace into it).
        logger.info(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {aspect_ratio}, "
            f"resize to {max(new_height, new_width) / min(new_height, new_width)}"
        )

        height = new_height
        width = new_width

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Shrink both sides by the same factor, rounding down so the budget holds.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Enlarge both sides by the same factor, rounding up to reach the minimum.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
        raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")

    return h_bar, w_bar
|
|
|
|
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Report whether the pixel values already lie in [0, 1].

    Arrays of dtype uint8 are always reported as not scaled; any other dtype is
    reported as scaled when its minimum is >= 0 and its maximum is <= 1.
    """
    if image.dtype != np.uint8:
        return np.min(image) >= 0 and np.max(image) <= 1
    return False
|
|
|
|
def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Normalise the accepted image input forms into a flat list of images.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            A nested list of images, a list of images, or a single image.

    Returns:
        list: For nested input, the inner lists concatenated in order; for a
        list of images, the input itself; for a single image, a one-element list.

    Raises:
        ValueError: If the input matches none of the accepted forms.
    """
    is_sequence = isinstance(images, (list, tuple))

    # Nested form: flatten one level.
    if (
        is_sequence
        and isinstance(images[0], (list, tuple))
        and is_valid_image(images[0][0])
    ):
        flattened = []
        for group in images:
            flattened.extend(group)
        return flattened

    # Already a flat list of images.
    if is_sequence and is_valid_image(images[0]):
        return images

    # Single image.
    if is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
|
|
|
|
| |
def make_batched_videos(videos) -> List[VideoInput]:
    """
    Normalise the accepted video input forms into a list of videos (lists of frames).

    Accepted forms, checked in this order:
    - a list of frame lists: returned unchanged;
    - a list whose first item is a PIL image: wrapped as a single video;
    - a list whose first item has a 4-dimensional `shape`: each item becomes a
      list of its frames;
    - a single valid image object with a 4-dimensional `shape`: becomes one
      video, a list of its frames.

    Raises:
        ValueError: If the input matches none of the accepted forms.
    """
    is_sequence = isinstance(videos, (list, tuple))

    if (
        is_sequence
        and isinstance(videos[0], (list, tuple))
        and is_valid_image(videos[0][0])
    ):
        return videos

    if is_sequence and is_valid_image(videos[0]):
        first = videos[0]
        if isinstance(first, Image.Image):
            # A flat list of PIL frames is treated as one video.
            return [videos]
        if len(first.shape) == 4:
            return [list(video) for video in videos]
        # Any other valid image falls through to the error below.
    elif is_valid_image(videos) and len(videos.shape) == 4:
        return [list(videos)]

    raise ValueError(f"Could not make batched video from {videos}")
|
|
|
|
class Ernie4_5_VLImageProcessor(BaseImageProcessor):
    r"""
    Constructs an adaptive image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel
            in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The min pixels of the image to resize the image.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The max pixels of the image to resize the image.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_conv_size (`int`, *optional*, defaults to 2):
            The temporal conv size in resampler.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size of the vision encoder to llm encoder.
    """

    model_input_names = [
        "pixel_values",
        "image_grid_thw",
        "pixel_values_videos",
        "video_grid_thw",
    ]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[float, List[float]] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_conv_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        """Store the preprocessing configuration; see the class docstring for the arguments."""
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Fall back to the OPENAI_CLIP statistics when none are supplied.
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_conv_size = temporal_conv_size
        self.merge_size = merge_size
        # `size` mirrors the pixel budget and is kept in sync by `set_pixels`.
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb

    def set_pixels(self, min_pixels=None, max_pixels=None, msg=""):
        """
        Update the pixel budget used when resizing.

        Args:
            min_pixels (int, optional): New lower bound (0 is accepted); left
                unchanged when `None`.
            max_pixels (int, optional): New upper bound (must be > 0); left
                unchanged when `None`.
            msg (str): Prefix added to the log line emitted for each change.

        Raises:
            AssertionError: If a supplied value is not an int in the accepted range.
        """
        if min_pixels is not None:
            assert (
                isinstance(min_pixels, int) and min_pixels >= 0
            ), "min_pixels must be a non-negative int"
            logger.info(
                f"{msg} Ernie4_5_VLImageProcessor set min_pixels = {min_pixels}"
            )
            self.min_pixels = min_pixels
            self.size["min_pixels"] = int(min_pixels)
        if max_pixels is not None:
            assert (
                isinstance(max_pixels, int) and max_pixels > 0
            ), "max_pixels must be positive int"
            logger.info(
                f"{msg} Ernie4_5_VLImageProcessor set max_pixels = {max_pixels}"
            )
            self.max_pixels = max_pixels
            self.size["max_pixels"] = int(max_pixels)

    def get_smarted_resize(self, height, width, min_pixels=None, max_pixels=None):
        """
        Compute the resize target for an image of the given size.

        Args:
            height (int): Original height.
            width (int): Original width.
            min_pixels (int, optional): Overrides `self.min_pixels` for this call.
            max_pixels (int, optional): Overrides `self.max_pixels` for this call.

        Returns:
            Two tuples: `(resized_height, resized_width)` in pixels, and the same
            size divided by `self.patch_size` (the patch grid).
        """
        actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=self.patch_size * self.merge_size,
            min_pixels=actual_min_pixels,
            max_pixels=actual_max_pixels,
        )
        return (resized_height, resized_width), (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )

    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        do_resize: bool = True,
        resample: PILImageResampling = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = False,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        predetermined_grid_thw=None,
    ):
        """
        Preprocess one image, or the frames of one video, into flattened patches.

        Args:
            images (`ImageInput` or `VideoInput`):
                Image or list of frames. Expects pixel values ranging from 0 to 255.
                If pixel values range from 0 to 1, set `do_rescale=False`.
            do_resize (`bool`):
                Whether to resize each image.
            resample (`PILImageResampling`):
                Resampling filter to use if resizing.
            do_rescale (`bool`):
                Whether to rescale pixel values by `rescale_factor`.
            rescale_factor (`float`):
                Scale factor to use if rescaling.
            do_normalize (`bool`):
                Whether to normalize with `image_mean` / `image_std`.
            image_mean (`float` or `List[float]`):
                Mean to use if normalizing.
            image_std (`float` or `List[float]`):
                Standard deviation to use if normalizing.
            do_convert_rgb (`bool`):
                Whether to convert each image to RGB first.
            data_format (`ChannelDimension`, defaults to `ChannelDimension.FIRST`):
                Channel layout each image is converted to before patch extraction.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel layout of the input; inferred from the first image when unset.
            predetermined_grid_thw (*optional*):
                One `(grid_h, grid_w)` pair per image, in units of patches. When
                given (and `do_resize` is on) it replaces the `smart_resize` result.

        Returns:
            `(flatten_patches, (grid_t, grid_h, grid_w))`, where `flatten_patches`
            has shape `[grid_t * grid_h * grid_w, channel * patch_size * patch_size]`.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations below operate on numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # Only the first image is inspected to infer the layout.
            input_data_format = infer_channel_dimension_format(images[0])

        # The size of the first image is used for every image in the list.
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []

        if predetermined_grid_thw is not None:
            assert len(predetermined_grid_thw) == len(
                images
            ), f"len(predetermined_grid_thw) {len(predetermined_grid_thw)} == len(images) {len(images)}"

        for img_idx, image in enumerate(images):
            if do_resize:
                if predetermined_grid_thw is not None:
                    # Grid sizes are in patches; convert them to pixels.
                    (resized_height, resized_width) = predetermined_grid_thw[img_idx]
                    resized_height *= self.patch_size
                    resized_width *= self.patch_size
                else:
                    resized_height, resized_width = smart_resize(
                        height,
                        width,
                        factor=self.patch_size * self.merge_size,
                        min_pixels=self.min_pixels,
                        max_pixels=self.max_pixels,
                    )

                image = resize(
                    image,
                    size=(resized_height, resized_width),
                    resample=resample,
                    data_format=input_data_format,
                )
            if do_rescale:
                image = rescale(
                    image, scale=rescale_factor, data_format=input_data_format
                )

            if do_normalize:
                image = normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    data_format=input_data_format,
                )

            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )

            processed_images.append(image)

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            # Patch extraction below needs the channel axis right after the frame axis.
            patches = patches.transpose([0, 3, 1, 2])

        channel = patches.shape[1]
        grid_t = patches.shape[0]
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        # Split H into (grid_h / merge, merge, patch) and W likewise.
        patches = patches.reshape(
            [
                grid_t,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            ]
        )
        # Reorder to [t, h/merge, w/merge, merge_h, merge_w, channel, patch_h, patch_w]
        # so the patches of one merge window end up adjacent after flattening.
        patches = patches.transpose([0, 2, 5, 3, 6, 1, 4, 7])

        flatten_patches = patches.reshape(
            [grid_t * grid_h * grid_w, channel * self.patch_size * self.patch_size]
        )

        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        videos: VideoInput = None,
        do_resize: Optional[bool] = None,
        size: Optional[Union[int, List[int]]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        predetermined_grid_thw=None,
    ):
        """
        Preprocess images and/or videos into flattened patches plus grid sizes.

        Every option left as `None` falls back to the value configured on the
        instance, so settings given to the constructor are honoured.

        Args:
            images (`ImageInput`):
                A single image or batch of images with pixel values ranging from 0 to 255, or `None`
                when only videos are given. If passing in images with pixel values between 0 and 1,
                set `do_rescale=False`.
            videos (`VideoInput`):
                A single video or batch of videos with pixel values ranging from 0 to 255. If
                passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (*optional*, defaults to `self.size`):
                Resolved against `self.size` but not otherwise used by this method; the resize
                target is derived from `self.min_pixels` / `self.max_pixels`.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum
                `PILImageResampling`. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if
                `do_normalize` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format images are converted to before patch extraction:
                - `"channels_first"` or `ChannelDimension.FIRST`: (num_channels, height, width).
                - `"channels_last"` or `ChannelDimension.LAST`: (height, width, num_channels).
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, it is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: (num_channels, height, width).
                - `"channels_last"` or `ChannelDimension.LAST`: (height, width, num_channels).
                - `"none"` or `ChannelDimension.NONE`: (height, width).
            predetermined_grid_thw (*optional*):
                `(grid_h, grid_w)` pairs in units of patches. For images, entry `i` is applied to
                image `i`; for videos, the whole sequence is forwarded for each video and must
                have one entry per frame.

        Returns:
            `BatchFeature` with `pixel_values` / `image_grid_thw` when images are given and
            `pixel_values_videos` / `video_grid_thw` when videos are given.

        Raises:
            ValueError: If `images` contains an unsupported image type.
        """
        # `None` means "use the instance configuration" for every option below.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        if images is not None:
            images = make_batched_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor."
            )

        # Options shared by every `_preprocess` call below.
        shared_options = dict(
            do_resize=do_resize,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
        )

        data = {}
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for img_idx, image in enumerate(images):
                # Each image is processed alone, so it gets a one-element grid list.
                if predetermined_grid_thw is not None:
                    predetermined_grid_thw_one = [predetermined_grid_thw[img_idx]]
                else:
                    predetermined_grid_thw_one = None
                patches, image_grid_thw = self._preprocess(
                    image,
                    predetermined_grid_thw=predetermined_grid_thw_one,
                    **shared_options,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data.update(
                {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
            )

        if videos is not None:
            videos = make_batched_videos(videos)
            pixel_values, vision_grid_thws = [], []
            for frames in videos:
                patches, video_grid_thw = self._preprocess(
                    frames,
                    predetermined_grid_thw=predetermined_grid_thw,
                    **shared_options,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(video_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)

            data.update(
                {
                    "pixel_values_videos": pixel_values,
                    "video_grid_thw": vision_grid_thws,
                }
            )

        return BatchFeature(data=data, tensor_type=return_tensors)
|
|
|
|
# Relative paths under ./download_tmp/. Presumably scratch locations for downloaded
# and derived media; their use is outside this block.
RAW_VIDEO_DIR = "./download_tmp/raw_video/"
RAW_IMAGE_DIR = "./download_tmp/raw_images/"
EXTRACTED_FRAME_DIR = "./download_tmp/extracted_frames/"
TMP_DIR = "./download_tmp/upload_tmp/"

# Font file expected next to this module; downloaded once at import time when missing.
FONT_PATH = os.path.join(Path(__file__).parent.absolute(), "Roboto-Regular.ttf")
if not os.path.exists(FONT_PATH):
    # A timeout keeps a stalled connection from hanging the import indefinitely.
    ttf = requests.get(
        "https://paddlenlp.bj.bcebos.com/vision-language-models/materials/Roboto-Regular.ttf",
        timeout=60,
    )
    # Never persist an HTTP error page as the font: later imports would find the
    # file, skip the download and keep the corrupt copy.
    ttf.raise_for_status()
    # write_bytes opens, writes and closes the file (no leaked handle).
    Path(FONT_PATH).write_bytes(ttf.content)
|
|
|
|
def is_gif(data: bytes) -> bool:
    """
    Tell whether a byte string starts with a GIF signature ("GIF87a" or "GIF89a").
    """
    header = data[:6]
    return header == b"GIF87a" or header == b"GIF89a"
|
|
|
|
class VideoReaderWrapper(decord.VideoReader):
    """
    `decord.VideoReader` subclass that seeks back to frame 0 after construction
    and after every indexed read, and that accepts GIF input by converting it
    to a temporary mp4 first.

    The rewind is the workaround referenced for the decord memory leak:
    https://github.com/dmlc/decord/issues/208

    NOTE: whatever path ends up in `self.original_file` (the string path passed
    in, or the converted mp4) is deleted from disk when the object is finalized.
    """

    def __init__(self, video_path, *args, **kwargs):
        """
        Args:
            video_path: A file path (`str`), raw bytes, or an `io.BytesIO`. GIF
                input (a path ending in ".gif", or bytes starting with a GIF
                signature) is converted to mp4 with moviepy before decord opens it.
            *args, **kwargs: Forwarded to `decord.VideoReader`.
        """
        # Assigned before anything can fail so that __del__ always finds it.
        self.original_file = None

        with ntf(delete=True, suffix=".gif") as gif_file:
            gif_input = None
            raw_bytes = None
            if isinstance(video_path, str):
                self.original_file = video_path
                if video_path.lower().endswith(".gif"):
                    gif_input = video_path
            elif isinstance(video_path, bytes):
                raw_bytes = video_path
            elif isinstance(video_path, io.BytesIO):
                # Read the whole stream and leave it positioned at the start.
                video_path.seek(0)
                raw_bytes = video_path.read()
                video_path.seek(0)

            if raw_bytes is not None and is_gif(raw_bytes):
                gif_file.write(raw_bytes)
                # The temp file is buffered: without a flush the converter, which
                # opens it by name, may see an empty or truncated file.
                gif_file.flush()
                gif_input = gif_file.name

            if gif_input is not None:
                try:
                    # Presumably the layout of older moviepy releases.
                    import moviepy.editor as mp
                except ImportError:
                    # Fallback for releases without `moviepy.editor`.
                    import moviepy as mp

                # Reserve a persistent mp4 path and close the handle immediately;
                # only the name is needed and the handle would otherwise leak.
                with ntf(delete=False, suffix=".mp4") as mp4_file:
                    mp4_path = mp4_file.name

                clip = mp.VideoFileClip(gif_input)
                try:
                    clip.write_videofile(mp4_path, logger=None)
                finally:
                    clip.close()

                video_path = mp4_path
                # From here on the temporary mp4 is the file removed in __del__.
                self.original_file = video_path

        super().__init__(video_path, *args, **kwargs)
        self.seek(0)

    def __getitem__(self, key):
        """Return the frame(s) at `key`, then seek back to frame 0."""
        frames = super().__getitem__(key)
        self.seek(0)
        return frames

    def __del__(self):
        """Run the parent finalizer if there is one, then delete the backing file."""
        # NOTE(review): decord.VideoReader is expected to release its native handle
        # in its own __del__, which this override would otherwise shadow; confirm
        # against the installed decord version.
        parent_del = getattr(super(), "__del__", None)
        if parent_del is not None:
            parent_del()

        # getattr guards against a partially constructed instance.
        original_file = getattr(self, "original_file", None)
        if original_file and os.path.exists(original_file):
            os.remove(original_file)
|
|
|
|
def get_filename(url=None):
    """
    Build a file name for a downloaded or generated resource.

    With no ``url`` the result is a random UUID4 rendered as 32 hex characters.
    Otherwise the name is
    ``<year>-<month>-<day>-<process id>-<thread id>-<md5 of url>``, where
    ``url`` may be ``str`` (encoded as UTF-8 before hashing) or ``bytes``.
    """
    if url is None:
        return uuid.uuid4().hex

    today = datetime.datetime.now()
    payload = url if isinstance(url, bytes) else url.encode("utf-8")
    digest = hashlib.md5(payload).hexdigest()

    # Process and thread ids keep names distinct between concurrent workers
    # that handle the same URL on the same day.
    name_parts = (
        f"{today.year}-{today.month:02d}-{today.day:02d}",
        str(os.getpid()),
        str(threading.get_ident()),
        digest,
    )
    return "-".join(name_parts)
|
|
|
|
def file_download(url, download_dir, save_to_disk=False, retry=0, retry_interval=3):
    """
    Resolve ``url`` to raw bytes or to a file on disk.

    Args:
        url: one of
            * a ``PIL.Image.Image`` or ``VideoReaderWrapper`` - returned as is;
            * a string starting with ``http`` - fetched with ``requests``;
            * the path of an existing local file;
            * any other string - decoded as base64 content.
        download_dir: directory the content is written to when
            ``save_to_disk`` is true.
        save_to_disk: when true, return a path instead of bytes. An existing
            local file is returned as its own path and is not copied.
        retry: number of additional HTTP attempts after the first one fails.
        retry_interval: seconds to wait between HTTP attempts.

    Returns:
        The object itself (image / reader input), the content as ``bytes``
        (``save_to_disk=False``), or the path of the saved file
        (``save_to_disk=True``).

    Raises:
        requests.RequestException: the download failed on every attempt or
            the server answered with an error status.
    """
    if isinstance(url, (Image.Image, VideoReaderWrapper)):
        return url

    if url.startswith("http"):
        import time  # local import: this module does not import ``time`` at the top

        attempts = max(retry, 0) + 1
        for attempt in range(1, attempts + 1):
            try:
                response = requests.get(url)
                # An error page must not be handed on as media content.
                response.raise_for_status()
                break
            except requests.RequestException as err:
                if attempt == attempts:
                    raise
                logger.warning(
                    "download of %s failed (attempt %d/%d): %s",
                    url,
                    attempt,
                    attempts,
                    err,
                )
                time.sleep(retry_interval)
        bytes_data = response.content
    elif os.path.isfile(url):
        if save_to_disk:
            return url
        with open(url, "rb") as local_file:
            bytes_data = local_file.read()
    else:
        bytes_data = base64.b64decode(url)

    if not save_to_disk:
        return bytes_data

    download_path = os.path.join(download_dir, get_filename(url))
    Path(download_path).parent.mkdir(parents=True, exist_ok=True)
    with open(download_path, "wb") as f:
        f.write(bytes_data)
    return download_path
|
|
|
|
def get_downloadable(
    url, download_dir=RAW_VIDEO_DIR, save_to_disk=False, retry=0, retry_interval=3
):
    """
    Make sure ``download_dir`` exists, then resolve ``url`` via ``file_download``.

    return downloaded **path** if save_to_disk is set to true
    return downloaded **bytes** if save_to_disk is set to false

    ``retry`` and ``retry_interval`` are forwarded unchanged.
    """
    # exist_ok avoids the exists()/makedirs() race when several workers start
    # at the same time.
    os.makedirs(download_dir, exist_ok=True)
    return file_download(
        url,
        download_dir,
        save_to_disk=save_to_disk,
        retry=retry,
        retry_interval=retry_interval,
    )
|
|
|
|
def get_downloadable_image(
    download_path, need_exif_info, retry_max_time=0, retry_interval=3
):
    """
    Load an image, optionally read its EXIF tags, and convert it to RGB.

    Args:
        download_path: anything ``get_downloadable`` accepts (PIL image, URL,
            local path or base64 string).
        need_exif_info: when true, try to read EXIF metadata.
        retry_max_time: forwarded to ``get_downloadable`` as ``retry``.
        retry_interval: forwarded to ``get_downloadable``.

    Returns:
        ``(image, exif_info)``: the image in RGB mode and a dict mapping EXIF
        tag names to values. The dict is empty when EXIF was not requested or
        could not be read.
    """

    def get_image_exif(image):
        """Map EXIF tag ids to names; text values lose surrounding whitespace."""
        exif_data = image._getexif()
        exif_info = {}
        if exif_data is not None:
            for tag, value in exif_data.items():
                tag_name = TAGS.get(tag, tag)
                # EXIF values are not all text (ints, tuples, ...). Calling
                # .strip() on those used to raise and discard every tag.
                if isinstance(value, (str, bytes)):
                    value = value.strip()
                exif_info[tag_name] = value
        return exif_info

    def has_transparent_background(img):
        """has_transparent_background"""
        if img.mode in ("RGBA", "LA") or (
            img.mode == "P" and "transparency" in img.info
        ):
            # The minimum of the alpha channel tells whether any pixel is
            # less than fully opaque.
            alpha = img.convert("RGBA").split()[-1]
            if alpha.getextrema()[0] < 255:
                return True
        return False

    def add_white_background(img):
        """
        Add a white background to a transparent background image
        """
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        img_white_background = Image.new("RGBA", img.size, (255, 255, 255))
        # The image is used as its own paste mask.
        img_white_background.paste(img, (0, 0), img)
        return img_white_background

    def change_I16_to_L(img):
        """
        Convert image from I;16 mode to L mode
        """
        # Scale pixel values by 1/256 before converting to 8-bit.
        return img.point(lambda i: i * (1 / 256)).convert("L")

    image = get_downloadable(
        download_path,
        save_to_disk=False,
        retry=retry_max_time,
        retry_interval=retry_interval,
    )
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.open(io.BytesIO(image))

    exif_info = {}
    if need_exif_info:
        try:
            exif_info = get_image_exif(pil_image)
        except Exception as why:
            # Best effort: missing or unreadable EXIF must not fail the load.
            logger.debug("could not read EXIF info: %s", why)
            exif_info = {}

    try:
        if pil_image.mode == "I;16":
            pil_image = change_I16_to_L(pil_image)
        if has_transparent_background(pil_image):
            pil_image = add_white_background(pil_image)
    except Exception as e:
        # Best effort: fall back to the plain RGB conversion below.
        logger.debug("image normalisation skipped: %s", e)

    return pil_image.convert("RGB"), exif_info
|
|
|
|
def read_video_decord(video_path, save_to_disk):
    """
    Open a video with decord and collect its basic metadata.

    ``video_path`` is resolved through ``get_downloadable``. Bytes are wrapped
    in ``io.BytesIO`` before a ``VideoReaderWrapper`` is created, and an
    existing ``VideoReaderWrapper`` is used as is.

    Returns:
        ``(video_reader, video_meta, video_path)`` where ``video_meta`` has the
        keys ``fps``, ``duration`` (frame count divided by fps) and
        ``num_of_frame``, and ``video_path`` is the resolved source handed to
        the reader (or the reader itself when one was passed in).
    """
    source = get_downloadable(video_path, save_to_disk=save_to_disk)

    if isinstance(source, VideoReaderWrapper):
        reader = source
    else:
        if isinstance(source, bytes):
            source = io.BytesIO(source)
        reader = VideoReaderWrapper(source, num_threads=1)

    frame_count = len(reader)
    avg_fps = reader.get_avg_fps()
    meta = {
        "fps": avg_fps,
        "duration": frame_count / float(avg_fps),
        "num_of_frame": frame_count,
    }
    return reader, meta, source
|
|
|
|
def get_frame_indices(
    vlen,
    target_frames=-1,
    target_fps=-1,
    frames_sample="middle",
    fix_start=None,
    input_fps=-1,
):
    """
    Choose which frame indices to sample from a video of ``vlen`` frames.

    Exactly one of ``target_frames`` / ``target_fps`` must be positive.

    * ``target_frames > 0``: the video is cut into that many consecutive
      buckets (at most ``vlen``) and one frame is taken per bucket: a random
      one (``"rand"``), ``bucket start + fix_start`` when ``fix_start`` is
      given, the first (``"leading"``) or the middle one (``"middle"``).
    * ``target_fps > 0``: frames are taken every ``1 / target_fps`` seconds,
      using ``input_fps`` to convert seconds to indices. Indices at or beyond
      ``vlen`` are dropped.

    Args:
        vlen: number of frames in the video.
        target_frames: number of frames to sample, or a non-positive value.
        target_fps: sampling rate in frames per second, or a non-positive value.
        frames_sample: ``"rand"``, ``"middle"`` or ``"leading"``.
        fix_start: fixed offset inside each bucket (``target_frames`` mode,
            ignored for ``"rand"``).
        input_fps: frame rate of the source video (``target_fps`` mode).

    Returns:
        List of frame indices.

    Raises:
        AssertionError: inconsistent arguments.
        ValueError: neither ``target_frames`` nor ``target_fps`` is positive.
    """
    assert frames_sample in ["rand", "middle", "leading"]
    if target_frames > 0:
        assert target_fps <= 0, "target_fps must be negative if target_frames is given."
        if target_frames > vlen:
            acc_samples = vlen
            logger.info(
                f"target_frames={target_frames} is larger than video length {vlen}, "
                f"will sample {acc_samples} frames."
            )
        else:
            acc_samples = target_frames
        logger.debug(
            f"sampling at target_frames={target_frames}, frames_sample={frames_sample}"
        )

        # Bucket boundaries; each bucket is (first index, last index), both
        # inclusive.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [
            (first, following - 1)
            for first, following in zip(intervals[:-1], intervals[1:])
        ]
        if frames_sample == "rand":
            # randint is inclusive on both ends, matching the bucket bounds.
            # range(first, last) skipped the last frame of every bucket and
            # raised on single-frame buckets.
            frame_indices = [
                random.randint(int(first), int(last)) for first, last in ranges
            ]
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif frames_sample == "leading":
            frame_indices = [x[0] for x in ranges]
        elif frames_sample == "middle":
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

    elif target_fps > 0:
        assert (
            target_frames <= 0
        ), "target_frames must be negative if target_fps is given."
        assert input_fps > 0, "input_fps must be provided if target_fps is given."
        logger.info(f"sampling at fps={target_fps}, frames_sample={frames_sample}")
        duration = float(vlen) / input_fps
        delta = 1 / target_fps  # seconds between two sampled frames
        if frames_sample == "middle":
            frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        elif frames_sample == "leading":
            frame_seconds = np.arange(0, duration, delta)
        elif frames_sample == "rand":
            # Start from the "middle" grid and jitter each sample by up to
            # half a step in either direction.
            frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
            rand_offset = np.random.rand(*(frame_seconds.shape)) - 0.5
            frame_seconds += rand_offset * delta
        else:
            raise NotImplementedError
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]

    else:
        raise ValueError(
            "Must provide either positive target_fps or positive target_frames."
        )

    return frame_indices
|
|
|
|
def _read_nearest_readable_frame(video_reader, frame_indice, tol):
    """Return ``(frame, index)`` of the closest readable neighbour of ``frame_indice``, or ``None``."""
    for offset in range(1, tol):
        # Probe order is -1, +1, -2, +2, ... up to tol - 1 on each side.
        for candidate in (frame_indice - offset, frame_indice + offset):
            if candidate < frame_indice:
                if candidate < 0:
                    continue
            elif candidate >= len(video_reader):
                continue
            try:
                frame = video_reader[candidate].asnumpy()
            except Exception:
                # This neighbour is unreadable as well; keep searching.
                continue
            logger.info(
                f"replace {frame_indice}-th frame with {candidate}-th frame"
            )
            return frame, candidate
    return None


def read_frames_decord(
    video_path,
    video_reader,
    video_meta,
    target_frames=-1,
    target_fps=-1,
    frames_sample="middle",
    fix_start=None,
    save_to_disk=False,
    cache_dir=EXTRACTED_FRAME_DIR,
    frame_indices=None,
    tol=10,
):
    """
    Extract frames from an opened video.

    Args:
        video_path: not used by this function.
        video_reader: reader indexable by frame number whose items provide
            ``asnumpy()``.
        video_meta: dict with the keys ``num_of_frame``, ``fps`` and
            ``duration``.
        target_frames, target_fps, frames_sample, fix_start: forwarded to
            ``get_frame_indices`` when ``frame_indices`` is not given.
        save_to_disk: when true, frames are written as PNG files under
            ``cache_dir/<random name>/`` and their paths are returned instead
            of images.
        cache_dir: root directory for saved frames.
        frame_indices: explicit indices to read. The sequence is updated in
            place when an unreadable frame is replaced by a neighbour.
        tol: an unreadable frame is replaced by the nearest readable frame
            less than ``tol`` positions away (``2 * tol`` for the first and
            last frame of the video).

    Returns:
        ``(frames, frame_indices, time_stamps)``: PIL images (or file paths),
        the indices actually read, and the time stamp of each index computed
        as ``index * duration / num_of_frame``.

    Raises:
        AssertionError: a frame could not be read or replaced, so the number
            of frames differs from the number of indices.
    """
    if frame_indices is None:
        frame_indices = get_frame_indices(
            video_meta["num_of_frame"],
            target_frames=target_frames,
            target_fps=target_fps,
            frames_sample=frames_sample,
            fix_start=fix_start,
            input_fps=video_meta["fps"],
        )

    frames = []
    for position, frame_indice in enumerate(frame_indices):
        try:
            frames.append(video_reader[frame_indice].asnumpy())
            continue
        except Exception as e:
            logger.debug(f"encounter error when get frame: {frame_indice}, error: {e}")

        # The first and last frame have neighbours on one side only, so the
        # search window is doubled for them.
        at_boundary = frame_indice == 0 or frame_indice == len(video_reader) - 1
        cur_tol = tol * 2 if at_boundary else tol
        replacement = _read_nearest_readable_frame(video_reader, frame_indice, cur_tol)
        if replacement is not None:
            frame, used_index = replacement
            frames.append(frame)
            frame_indices[position] = used_index

    frames = np.stack(frames, axis=0)
    assert len(frames) == len(
        frame_indices
    ), f"len(frames): {len(frames)} != len(frame_indices): {len(frame_indices)}"

    ret = []

    url_sha1 = get_filename()
    frame_dir = os.path.join(cache_dir, f"{url_sha1}")
    if save_to_disk:
        # Created once instead of being re-checked for every frame.
        os.makedirs(frame_dir, exist_ok=True)
    for idx, frame in enumerate(frames):
        tmp = Image.fromarray(frame, "RGB")
        if save_to_disk:
            save_path = os.path.join(frame_dir, f"{idx}.png")
            tmp.save(save_path)
            tmp = save_path
        ret.append(tmp)

    time_stamps = [
        frame_idx * video_meta["duration"] / video_meta["num_of_frame"]
        for frame_idx in frame_indices
    ]

    return ret, frame_indices, time_stamps
|
|
|
|
def render_single_image_with_timestamp(
    image: Image.Image, number: str, rate: float, font_path: str = FONT_PATH
):
    """
    Draw the text ``number`` into the top-left corner of ``image``.

    The font size is ``rate`` times the shorter image side (at least 1 pixel).
    The text is black with a white outline whose width is 10% of the font size.

    Args:
        image: PIL image to draw on; it is modified in place.
        number: text to render.
        rate: font size as a fraction of ``min(width, height)``.
        font_path: TrueType font file to use.

    Returns:
        The same ``image`` object, with the text drawn on it.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    # Clamp to 1 so very small frames do not request a zero-sized font.
    font_size = max(1, int(min(width, height) * rate))
    outline_size = int(font_size * 0.1)
    font = ImageFont.truetype(font_path, font_size)

    draw.text(
        (0, 0),
        number,
        font=font,
        fill=(0, 0, 0),
        stroke_width=outline_size,
        stroke_fill=(255, 255, 255),
    )

    return image
|
|
|
|
def timestamp_converting(time_stamp_in_seconds):
    """
    convert timestamp format from seconds to hr:min:sec

    The result looks like ``HH:MM:SS.ss``: hours and minutes are zero-padded
    to two digits, seconds are shown with two decimals. Values below 60
    (including negative ones) are formatted in the seconds field only.

    Args:
        time_stamp_in_seconds: duration in seconds (int or float).

    Returns:
        The formatted string, e.g. ``"01:01:01.50"`` for ``3661.5``.
    """
    hours = mins = 0
    secs = time_stamp_in_seconds
    if secs >= 60:
        hours, secs = divmod(secs, 3600)
        mins, secs = divmod(secs, 60)

    # Two-decimal formatting can round 59.995+ up to "60.00"; carry that into
    # the minutes (and hours) instead of printing an invalid seconds field.
    if f"{secs:.2f}" == "60.00":
        secs = 0
        mins += 1
        if mins == 60:
            mins = 0
            hours += 1

    time_hours = f"{int(hours):02d}"
    time_mins = f"{int(mins):02d}"
    time_secs = f"{secs:05.02f}"
    return time_hours + ":" + time_mins + ":" + time_secs
|
|
|
|
def render_frame_timestamp(frame, timestamp, font_rate=0.1):
    """
    Stamp a ``time: HH:MM:SS.ss`` label onto the upper-left corner of a frame.

    frame: frame, PIL.Image object
    timestamp: timestamp, in seconds
    font_rate: the ratio of font size to min(wi, hei)

    Returns whatever ``render_single_image_with_timestamp`` returns for the
    frame.
    """
    label = "time: " + timestamp_converting(timestamp)
    return render_single_image_with_timestamp(frame, label, font_rate)
|
|
|
|
# Integer tag per modality, used for the entries of ``token_type_ids``.
IDS_TYPE_FLAG = {
    "text": 0,
    "image": 1,
    "video": 2,
    "audio": 3,
}
|
|
|
|
class Ernie4_5_VLProcessor(ProcessorMixin):
    """
    Processes multimodal chat messages into model-ready inputs,
    handling text, images, and videos with 3D positional embeddings.

    ``__call__`` expects a prompt in which every image is marked by
    ``<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>`` and every video by
    ``<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>``. Placeholders are
    matched, in order, with the entries of ``images`` / ``videos``.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "spatial_conv_size",
        "temporal_conv_size",
        "image_min_pixels",
        "image_max_pixels",
        "video_min_pixels",
        "video_max_pixels",
        "video_target_frames",
        "video_frames_sample",
        "video_max_frames",
        "video_min_frames",
        "video_fps",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    CLS_TOKEN = "<|begin_of_sentence|>"
    SEP_TOKEN = "<|end_of_sentence|>"
    IMG_START = "<|IMAGE_START|>"
    IMG_END = "<|IMAGE_END|>"
    VID_START = "<|VIDEO_START|>"
    VID_END = "<|VIDEO_END|>"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        spatial_conv_size: int = 2,
        temporal_conv_size: int = 2,
        image_min_pixels: int = 4 * 28 * 28,
        image_max_pixels: int = 6177 * 28 * 28,
        video_min_pixels: int = 299 * 28 * 28,
        video_max_pixels: int = 1196 * 28 * 28,
        video_target_frames: int = -1,
        video_frames_sample: str = "leading",
        video_max_frames: int = 180,
        video_min_frames: int = 16,
        video_fps: int = 2,
        **kwargs,
    ):
        """
        Args:
            image_processor: image processor used to patchify images / videos.
            tokenizer: tokenizer used for text and special tokens.
            chat_template: forwarded to ``ProcessorMixin``.
            spatial_conv_size: factor by which the patch grid is reduced along
                height and width when counting placeholder tokens.
            temporal_conv_size: factor by which the frame count is reduced
                when counting video placeholder tokens.
            image_min_pixels, image_max_pixels: pixel budget passed to
                ``image_processor.get_smarted_resize`` for images.
            video_min_pixels, video_max_pixels: the same, for video frames.
            video_target_frames: default number of frames to sample (-1: use
                ``video_fps`` instead).
            video_frames_sample: default sampling strategy.
            video_max_frames, video_min_frames: bounds on the number of frames
                sampled in fps mode.
            video_fps: default sampling rate.
            **kwargs: accepted but not used.
        """
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.tokenizer.ignored_index = -100

        # Convolution sizes for patch aggregation
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size

        # Pixel constraints
        self.image_min_pixels = image_min_pixels
        self.image_max_pixels = image_max_pixels
        self.video_min_pixels = video_min_pixels
        self.video_max_pixels = video_max_pixels

        # Video sampling parameters
        self.target_frames = video_target_frames
        self.frames_sample = video_frames_sample
        self.max_frames = video_max_frames
        self.min_frames = video_min_frames
        self.fps = video_fps

        # Special tokens and IDs
        self.cls_token = self.CLS_TOKEN
        self.sep_token = self.SEP_TOKEN
        self.image_start = self.IMG_START
        self.image_end = self.IMG_END
        self.video_start = self.VID_START
        self.video_end = self.VID_END
        self.image_patch_id = self.tokenizer.convert_tokens_to_ids(
            "<|IMAGE_PLACEHOLDER|>"
        )

        self.token_type_mapping = self._build_token_type_mapping()
        self.is_training = True
        self.role_prefixes = {"system": "", "user": "User: ", "bot": "Assistant: "}

    def _build_token_type_mapping(self) -> Dict[Any, int]:
        """Return a token -> type-flag mapping that defaults to the text flag."""
        mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"])
        # NOTE(review): the video delimiters are tagged with the image flag
        # too - confirm this is intended.
        for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END):
            mapping[token] = IDS_TYPE_FLAG["image"]
        mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"]
        return mapping

    def train(self) -> None:
        """Set ``is_training`` to True. The flag is only stored by this class."""
        self.is_training = True

    def eval(self) -> None:
        """Set ``is_training`` to False. The flag is only stored by this class."""
        self.is_training = False

    def _download_image(
        self,
        item: Dict,
    ):
        """
        Load the image described by ``item["image_url"]``.

        The entry provides ``url`` and, optionally, ``image_width`` and
        ``image_height``; the image is resized only when both are given.
        """
        url_info = item.get("image_url", {})
        url = url_info.get("url")
        w = url_info.get("image_width", None)
        h = url_info.get("image_height", None)
        data = get_downloadable(url, download_dir=RAW_IMAGE_DIR, save_to_disk=False)

        if isinstance(data, Image.Image):
            # get_downloadable hands PIL images back unchanged; Image.open
            # cannot take one as input.
            img = data
        else:
            img = Image.open(io.BytesIO(data) if isinstance(data, bytes) else data)
        if w and h:
            img = img.resize((w, h))
        return img

    def _download_video(self, item: Dict):
        """
        Load the video described by ``item["video_url"]`` as one array.

        Frames are sampled and stamped by ``_load_and_process_video``,
        converted to RGB and stacked along a new leading axis.
        """
        url_info = item.get("video_url", {})
        url = url_info.get("url")

        frames = self._load_and_process_video(url, item)

        pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
        return pixel_stack

    def process_vision_info(self, messages: List[Dict[str, Any]]):
        """
        Collect the images and videos referenced by chat messages.

        Each message's ``content`` may be a single item or a list of items.
        Dict items of type ``image_url`` / ``video_url`` are loaded; anything
        else (for instance plain-text content) is skipped.

        Returns:
            ``(images, videos)``: PIL images and stacked frame arrays, in the
            order in which they appear in ``messages``.
        """
        images = []
        videos = []

        for msg in messages:
            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]

            for item in content_items:
                if not isinstance(item, dict):
                    # Plain strings (or missing content) carry no media and
                    # have no ``.get``.
                    continue
                item_type = item.get("type")
                if item_type == "image_url":
                    images.append(self._download_image(item))
                elif item_type == "video_url":
                    videos.append(self._download_video(item))

        return images, videos

    def __call__(
        self,
        text: Union[str, List[str]],
        images: List[Image.Image] = None,
        videos: List[List[Image.Image]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Convert one prompt plus its images / videos into model inputs.

        Args:
            text: the prompt, or a list of prompts of which only the first is
                processed.
            images: one entry per image placeholder in the prompt.
            videos: one entry per video placeholder in the prompt; each entry
                is indexed as ``(frames, height, width, ...)``.
            **kwargs: accepted but not used.

        Returns:
            A ``BatchFeature`` with ``input_ids``, ``token_type_ids``,
            ``position_ids``, ``images``, ``grid_thw`` and ``image_type_ids``.
            The last three are ``None`` when the prompt contains no media.
            Array values are converted to tensors; all except ``images`` and
            ``grid_thw`` get a leading batch axis of size 1.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }
        if images is None:
            images = []
        if videos is None:
            videos = []
        if not isinstance(text, list):
            text = [text]
        if len(text) > 1:
            logger.warning(
                "Ernie4_5_VLProcessor handles a single prompt per call; "
                "%d additional prompt(s) are ignored.",
                len(text) - 1,
            )
        prompt = text[0]

        video_placeholder = self.VID_START + "<|video@placeholder|>" + self.VID_END
        image_placeholder = self.IMG_START + "<|image@placeholder|>" + self.IMG_END

        # Splitting on a placeholder leaves the text between media; every
        # segment except the first is preceded by one video / image.
        for video_pos, text_with_image in enumerate(prompt.split(video_placeholder)):
            if video_pos > 0:
                self._add_video(videos[outputs["video_cnt"]], outputs)
            for image_pos, segment in enumerate(
                text_with_image.split(image_placeholder)
            ):
                if image_pos > 0:
                    self._add_image(images[outputs["pic_cnt"]], outputs)
                self._add_text(segment, outputs)

        # Bookkeeping counters are not part of the model inputs.
        for key in ["cur_position", "pic_cnt", "video_cnt"]:
            outputs.pop(key, None)

        outputs = self._pack_outputs(outputs)
        for key, value in outputs.items():
            if not isinstance(value, np.ndarray):
                continue
            if key in ["images", "grid_thw"]:
                outputs[key] = torch.tensor(np.array(value))
            else:
                # Add the batch axis.
                outputs[key] = torch.tensor(np.array([value]))

        return BatchFeature(data=outputs)

    def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
        """Append one special token (given as string or id) to ``outputs``."""
        token_id = (
            token
            if isinstance(token, int)
            else self.tokenizer.convert_tokens_to_ids(token)
        )
        outputs["input_ids"].append(token_id)
        outputs["token_type_ids"].append(self.token_type_mapping[token])
        pos = outputs["cur_position"]
        # The same value is used on all three position axes.
        outputs["position_ids"].append([pos] * 3)
        outputs["cur_position"] += 1

    def _add_text(self, text: str, outputs: Dict) -> None:
        """Tokenize ``text`` and append it with consecutive positions."""
        tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        outputs["input_ids"].extend(tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens))

        start = outputs["cur_position"]
        outputs["position_ids"].extend(
            [start + offset] * 3 for offset in range(len(tokens))
        )
        outputs["cur_position"] += len(tokens)

    def _add_image(self, img: Image.Image, outputs: Dict) -> None:
        """Append one image: start token, patch placeholders, pixels, end token."""
        outputs["pic_cnt"] += 1
        self._add_special_token(self.IMG_START, outputs)

        # The second element returned by get_smarted_resize is unpacked as the
        # patch grid height and width.
        patches_h, patches_w = self.image_processor.get_smarted_resize(
            img.height,
            img.width,
            min_pixels=self.image_min_pixels,
            max_pixels=self.image_max_pixels,
        )[1]
        num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)

        pos_ids = self._compute_3d_positions(
            1, patches_h, patches_w, outputs["cur_position"]
        )
        outputs["position_ids"].extend(pos_ids)
        # Following tokens continue after the largest position used so far.
        outputs["cur_position"] = np.max(pos_ids) + 1

        # Pixel values are kept unscaled and unnormalised here.
        ret = self.image_processor.preprocess(
            images=[img.convert("RGB")],
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]]),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values"])
        outputs["grid_thw"].append(ret["image_grid_thw"])
        outputs["image_type_ids"].append(0)

        self._add_special_token(self.IMG_END, outputs)

    def _add_video(
        self, pixel_stack: np.ndarray, outputs: Dict
    ) -> None:
        """
        Append one video: start token, patch placeholders, pixels, end token.

        ``pixel_stack`` is read as ``shape[0]`` frames of height ``shape[1]``
        and width ``shape[2]``.
        """
        outputs["video_cnt"] += 1
        self._add_special_token(self.VID_START, outputs)

        patches_h, patches_w = self.image_processor.get_smarted_resize(
            pixel_stack.shape[1],
            pixel_stack.shape[2],
            min_pixels=self.video_min_pixels,
            max_pixels=self.video_max_pixels,
        )[1]
        num_frames = pixel_stack.shape[0]
        num_tokens = (num_frames * patches_h * patches_w) // (
            self.spatial_conv_size**2 * self.temporal_conv_size
        )

        ret = self.image_processor.preprocess(
            images=None,
            videos=pixel_stack,
            do_normalize=False,
            do_rescale=False,
            predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames),
            do_convert_rgb=True,
            input_data_format=ChannelDimension.LAST,
        )
        outputs["images"].append(ret["pixel_values_videos"])
        outputs["grid_thw"].append(ret["video_grid_thw"])
        # One entry per frame, tagged 1 (images are tagged 0).
        outputs["image_type_ids"].extend([1] * num_frames)

        outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)

        pos_ids = self._compute_3d_positions(
            num_frames, patches_h, patches_w, outputs["cur_position"]
        )
        outputs["position_ids"].extend(pos_ids)
        outputs["cur_position"] = np.max(pos_ids) + 1

        self._add_special_token(self.VID_END, outputs)

    def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
        """
        Read a video, sample frames and stamp each with its time.

        Sampling options (``fps``, ``min_frames``, ``max_frames``,
        ``target_frames``, ``frames_sample``) are taken from ``item`` and fall
        back to the processor defaults.
        """
        reader, meta, path = read_video_decord(url, save_to_disk=False)

        video_frame_args = {
            "fps": item.get("fps", self.fps),
            "min_frames": item.get("min_frames", self.min_frames),
            "max_frames": item.get("max_frames", self.max_frames),
            "target_frames": item.get("target_frames", self.target_frames),
            "frames_sample": item.get("frames_sample", self.frames_sample),
        }
        video_frame_args = self._set_video_frame_args(video_frame_args, meta)

        frames_data, _, timestamps = read_frames_decord(
            path,
            reader,
            meta,
            target_frames=video_frame_args["target_frames"],
            target_fps=video_frame_args["fps"],
            frames_sample=video_frame_args["frames_sample"],
            save_to_disk=False,
        )

        frames: List[Image.Image] = [
            render_frame_timestamp(img_array, ts)
            for img_array, ts in zip(frames_data, timestamps)
        ]
        # Ensure an even number of frames by repeating the last one -
        # presumably so frames pair up under temporal_conv_size == 2; confirm.
        if len(frames) % 2 != 0:
            frames.append(copy.deepcopy(frames[-1]))
        return frames

    def _set_video_frame_args(self, video_frame_args, video_meta):
        """
        Set the final frame extraction parameters based on known parameters and priorities

        A positive ``target_frames`` wins and requires a negative ``fps``.
        Otherwise ``fps`` is used, unless the resulting frame count falls
        outside ``[min_frames, max_frames]``; the violated bound then becomes
        ``target_frames`` and ``fps`` is set to -1.

        ``video_frame_args`` is updated in place and returned.

        Raises:
            ValueError: the arguments contradict each other.
        """
        if video_frame_args["target_frames"] > 0:
            if video_frame_args["fps"] >= 0:
                raise ValueError("fps must be negative if target_frames is given")
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["target_frames"] < video_frame_args["min_frames"]
            ):
                raise ValueError("target_frames must be larger than min_frames")
            if (
                video_frame_args["max_frames"] > 0
                and video_frame_args["target_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("target_frames must be smaller than max_frames")
        else:
            if video_frame_args["fps"] < 0:
                raise ValueError(
                    "Must provide either positive target_fps or positive target_frames."
                )
            # Number of frames that sampling at ``fps`` would yield.
            frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"])
            if (
                video_frame_args["min_frames"] > 0
                and video_frame_args["max_frames"] > 0
                and video_frame_args["min_frames"] > video_frame_args["max_frames"]
            ):
                raise ValueError("min_frames must be smaller than max_frames")
            if (
                video_frame_args["min_frames"] > 0
                and frames_to_extract < video_frame_args["min_frames"]
            ):
                video_frame_args["target_frames"] = video_frame_args["min_frames"]
                video_frame_args["fps"] = -1
            if (
                video_frame_args["max_frames"] > 0
                and frames_to_extract > video_frame_args["max_frames"]
            ):
                video_frame_args["target_frames"] = video_frame_args["max_frames"]
                video_frame_args["fps"] = -1

        return video_frame_args

    def _compute_3d_positions(
        self, t: int, h: int, w: int, start_idx: int
    ) -> List[List[int]]:
        """
        Return ``[time, height, width]`` position triples for a patch grid.

        The grid is first reduced by ``temporal_conv_size`` along time (a
        single frame stays a single step) and by ``spatial_conv_size`` along
        height and width. Triples are listed with width varying fastest, and
        ``start_idx`` is added to every coordinate.
        """
        t_eff = t // self.temporal_conv_size if t != 1 else 1
        gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size
        time_idx = np.repeat(np.arange(t_eff), gh * gw)
        h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
        w_idx = np.tile(np.arange(gw), t_eff * gh)

        # One vectorised offset instead of a Python loop over every position.
        coords = np.stack([time_idx, h_idx, w_idx], axis=1) + start_idx
        return coords.tolist()

    def _pack_outputs(self, outs: Dict) -> Dict[str, Any]:
        """Turn the accumulated lists into arrays (``None`` for absent media)."""
        if not outs["images"]:
            outs["images"] = None
            outs["grid_thw"] = None
            outs["image_type_ids"] = None
        else:
            outs["images"] = np.vstack(outs["images"])
            outs["grid_thw"] = np.vstack(outs["grid_thw"])
            outs["image_type_ids"] = np.array(outs["image_type_ids"])

        outs["input_ids"] = np.array(outs["input_ids"], dtype=np.int64)
        outs["token_type_ids"] = np.array(outs["token_type_ids"], dtype=np.int64)
        outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64)
        return outs

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image processor input names."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(tokenizer_input_names) + list(image_processor_input_names)
|
|
|
|
# Public names exported by this module.
__all__ = [
    "Ernie4_5_VLTokenizer",
    "Ernie4_5_VLImageProcessor",
    "Ernie4_5_VLProcessor",
]