| from fractions import Fraction |
|
|
| from transformers import LlavaNextProcessor |
| from transformers.image_processing_utils import select_best_resolution |
|
|
|
|
|
|
class Granite4VisionProcessor(LlavaNextProcessor):
    """LLaVA-NeXT processor variant whose image-token count honours a downsample rate.

    Everything is inherited from ``LlavaNextProcessor`` except the image-token
    count. When ``downsample_rate`` is set, each side of the patch grid is
    scaled by that rate and then truncated to an integer. This happens before
    the base, unpadded and newline feature counts are computed.
    """

    model_type = "granite4_vision"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        patch_size=None,
        vision_feature_select_strategy=None,
        chat_template=None,
        image_token="<image>",
        num_additional_image_tokens=0,
        downsample_rate=None,
        **kwargs,
    ):
        """Create the processor.

        Args:
            image_processor: Forwarded to ``LlavaNextProcessor``.
            tokenizer: Forwarded to ``LlavaNextProcessor``.
            patch_size: Forwarded to ``LlavaNextProcessor``.
            vision_feature_select_strategy: Forwarded to ``LlavaNextProcessor``.
            chat_template: Forwarded to ``LlavaNextProcessor``.
            image_token: Forwarded to ``LlavaNextProcessor``.
            num_additional_image_tokens: Forwarded to ``LlavaNextProcessor``.
                Added once to the base feature count.
            downsample_rate: Optional per-side scale for the patch grid.
                Accepts any value ``fractions.Fraction`` does, e.g. ``"1/2"``,
                ``Fraction(1, 2)`` or ``0.5``. ``None`` disables downsampling.
                Stored as given and only converted when counting tokens.
            **kwargs: Accepted but not forwarded to the parent constructor.
        """
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            patch_size=patch_size,
            vision_feature_select_strategy=vision_feature_select_strategy,
            chat_template=chat_template,
            image_token=image_token,
            num_additional_image_tokens=num_additional_image_tokens,
        )
        # Keep the raw value; the exact Fraction is derived on demand.
        self.downsample_rate = downsample_rate

    def _downsample_fraction(self):
        """Return ``self.downsample_rate`` as an exact ``Fraction``, or ``None`` if unset.

        Raises:
            ValueError: If the rate is zero or negative, or if it is a string
                that ``Fraction`` cannot parse.
        """
        if self.downsample_rate is None:
            return None
        rate = Fraction(self.downsample_rate)
        if rate <= 0:
            raise ValueError(
                f"downsample_rate must be a positive number, got {self.downsample_rate!r}"
            )
        return rate

    def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
        """Return how many image tokens a single image expands to.

        Args:
            orig_height: Height of the original image.
            orig_width: Width of the original image.
            height: Height of one tile. Presumably the image processor's
                crop size; confirm against the caller.
            width: Width of one tile (same assumption as ``height``).

        Returns:
            The sum of unpadded features, newline features and base features.
            The base view contributes one token per (possibly downsampled)
            patch plus ``num_additional_image_tokens``.

        Raises:
            ValueError: If ``downsample_rate`` is not positive, or if
                downsampling shrinks the patch grid to zero rows or columns.
                Raising here stops a zero-sized grid from reaching
                ``_get_unpadded_features``.
        """
        image_grid_pinpoints = self.image_processor.image_grid_pinpoints

        # Resolution chosen from image_grid_pinpoints for the original image size.
        best_height, best_width = select_best_resolution(
            [orig_height, orig_width], image_grid_pinpoints
        )
        # Floor ratio of that resolution to the (height, width) tile along each axis.
        scale_height, scale_width = best_height // height, best_width // width

        patches_height = height // self.patch_size
        patches_width = width // self.patch_size

        ds_rate = self._downsample_fraction()
        if ds_rate is not None:
            # Exact rational scaling followed by int() truncation (a floor for
            # the non-negative counts here).
            # NOTE(review): Fraction(float) is the exact binary value. A float
            # rate such as 1/3 is therefore slightly below 1/3, and
            # 24 * Fraction(1/3) truncates to 7 rather than 8. String rates
            # like "1/3" avoid this. Confirm this matches how the model
            # downsamples its features.
            patches_height = int(patches_height * ds_rate)
            patches_width = int(patches_width * ds_rate)
            if patches_height == 0 or patches_width == 0:
                raise ValueError(
                    f"downsample_rate={self.downsample_rate!r} reduces the "
                    f"{height // self.patch_size}x{width // self.patch_size} patch grid "
                    "to zero rows or columns"
                )

        unpadded_features, newline_features = self._get_unpadded_features(
            orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
        )

        # Base view: one token per (possibly downsampled) patch plus any extra tokens.
        base_features = patches_height * patches_width + self.num_additional_image_tokens
        return unpadded_features + newline_features + base_features
|
|