| """ |
| Processor class for Molmo. |
| """ |
|
|
| from typing import Optional |
|
|
| import PIL |
| from PIL import Image |
|
|
| try: |
| from typing import Unpack |
| except ImportError: |
| from typing_extensions import Unpack |
|
|
| import re |
| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| import torchvision.transforms.functional as F |
| from transformers import AutoTokenizer |
| from transformers.image_utils import ImageInput |
| from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, |
| TextKwargs) |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
|
|
# Sentinel indices and placeholder strings shared with the ChatRex model code.
IGNORE_INDEX = -100  # label value to skip in the loss; not referenced in this chunk
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200  # sentinel input id emitted in place of DEFAULT_IMAGE_TOKEN
DEFAULT_IMAGE_TOKEN = "<image>"


# Object placeholders; "<i>" inside DEFAULT_OBJECT_TOKEN is replaced by the box index.
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300  # sentinel input id emitted in place of DEFAULT_OBJECT_FEATURE_TOKEN


# Markup wrapping grounded text spans and their referenced objects
# (presumably used by the model's output format — not referenced in this chunk).
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
|
|
def xyxy_to_xywh(boxes):
    """
    Convert boxes from xyxy to xywh format.

    Note: the original docstring claimed the opposite direction; the code
    has always converted corner format to corner+size format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as
        [x_min, y_min, w, h].
    """
    boxes = np.array(boxes)
    x_min, y_min, x_max, y_max = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )
    # Plain max - min (no +1 inclusive-pixel convention).
    w = x_max - x_min
    h = y_max - y_min
    return np.stack([x_min, y_min, w, h], axis=1)
|
|
|
|
def xywh_to_xyxy(boxes):
    """
    Convert boxes from xywh to xyxy format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x, y, width, height].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as
        [x_min, y_min, x_max, y_max].
    """
    arr = np.array(boxes)
    x_min = arr[:, 0]
    y_min = arr[:, 1]
    # Bottom-right corner is top-left plus the box extent.
    x_max = x_min + arr[:, 2]
    y_max = y_min + arr[:, 3]
    return np.stack([x_min, y_min, x_max, y_max], axis=1)
|
|
def expand2square(pil_img, background_color):
    """Pad a PIL image to a centered square canvas filled with background_color.

    The shorter dimension is padded equally on both sides (extra pixel goes to
    the trailing side for odd differences). A square input is returned as-is.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        canvas.paste(pil_img, (0, (side - height) // 2))
    else:
        canvas.paste(pil_img, ((side - width) // 2, 0))
    return canvas
|
|
def pad_boxes(gt_boxes, old_size):
    """Shift xywh boxes to match the expand2square padding of an image.

    Args:
        gt_boxes: boxes as [x, y, w, h] on the original image.
        old_size: (width, height) of the original image before padding.

    Returns:
        numpy.ndarray (float32): boxes translated into the padded square image.
    """
    old_w, old_h = old_size
    boxes = np.array(gt_boxes).astype(np.float32)

    if old_w > old_h:
        # Landscape image: padding goes above/below, x stays put.
        offset_x = 0
        offset_y = (old_w - old_h) // 2
    else:
        # Portrait (or square) image: padding goes left/right, y stays put.
        offset_x = (old_h - old_w) // 2
        offset_y = 0

    boxes[:, 0] += offset_x
    boxes[:, 1] += offset_y
    return boxes
|
|
|
|
def resize_boxes(gt_boxes, old_size, new_size):
    """Scale xywh boxes from the padded square image to the resized image.

    Args:
        gt_boxes: boxes as [x, y, w, h] on the padded image.
        old_size: (width, height) of the image before resizing; only
            max(width, height) is used, matching the square padding step.
        new_size: (height, width) of the resized image.

    Returns:
        numpy.ndarray (float32): the scaled boxes.
    """
    old_w, old_h = old_size
    new_h, new_w = new_size
    boxes = np.array(gt_boxes).astype(np.float32)

    # The image was expanded to a square of side max(old_w, old_h) first,
    # so both axes scale relative to that side length.
    side = max(old_w, old_h)
    boxes[:, [0, 2]] *= new_w / side
    boxes[:, [1, 3]] *= new_h / side

    return boxes
|
|
def split_special_strings(input_string: str, special_strings: Optional[List[str]] = None):
    """Split the input string into a list of strings, keeping the special strings.

    Args:
        input_string (str): The input string to split.
        special_strings (list[str], optional): Substrings to keep as their own
            list entries. If omitted or empty, the input string is returned
            whole (previously a None default raised a TypeError).

    Example:

        input_string = "<image>\n<obj0><objfeat><obj1><objfeat>\n I am happy today."
        output = ['<image>', '\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\n I am happy today.']

    Returns:
        list: A list of non-empty strings, with the special strings separated
        from the rest of the input string.
    """
    if not special_strings:
        # Nothing to split on; also guards the previously-crashing None default.
        return [input_string] if input_string else []

    # Alternation of the escaped separators; the capturing group makes
    # re.split keep the separators in the output.
    pattern = "|".join(map(re.escape, special_strings))
    parts = re.split(f"({pattern})", input_string)

    # re.split yields empty strings between adjacent separators; drop them.
    return [part for part in parts if part]
|
|
def tokenizer_image_object_token(prompt, tokenizer):
    """Encode a prompt, mapping placeholder strings to sentinel token ids.

    "<image>" chunks become IMAGE_TOKEN_INDEX and "<objfeat>" chunks become
    DEFAULT_OBJECT_INDEX; all other text is encoded with the tokenizer. The
    sequence starts with the tokenizer's BOS id.
    """
    placeholder_ids = {
        DEFAULT_IMAGE_TOKEN: IMAGE_TOKEN_INDEX,
        DEFAULT_OBJECT_FEATURE_TOKEN: DEFAULT_OBJECT_INDEX,
    }
    token_ids = [tokenizer.bos_token_id]
    for piece in split_special_strings(prompt, list(placeholder_ids)):
        if piece in placeholder_ids:
            token_ids.append(placeholder_ids[piece])
        else:
            token_ids.extend(tokenizer.encode(piece, add_special_tokens=False))
    return token_ids
|
|
class ChatRexProcessor(ProcessorMixin):
    """Processor turning (image, boxes, question) into ChatRex model inputs.

    Pairs an image processor with a tokenizer (resolved via the HF auto
    classes named below) and exposes `process`, which returns a dict of
    batched tensors ready for inference.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor = None, tokenizer : AutoTokenizer = None, **kwargs):
        """Store the sub-processors and the chat prompt template."""
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None
        # Vicuna-style chat template; `process` only uses the INSTRUCTION entry.
        self.template = dict(
            SYSTEM=('A chat between a curious user and an artificial '
                    'intelligence assistant. The assistant gives '
                    'helpful, detailed, and polite answers to the '
                    'user\'s questions. {system}\n '),
            INSTRUCTION=('USER: {input} ASSISTANT:'),
            SEP='\n')

    def process(
        self,
        image: Union[str, Image.Image],
        bbox: List[List[int]],
        question: str,
    ):
        """Prepare input data for inference.

        Args:
            image (Union[str, Image.Image]): The image (or a path to it) to process.
            bbox (List[List[int]]): A list of bounding boxes for the image. Each
                bounding box is [x_min, y_min, x_max, y_max] in original-image pixels.
            question (str): The question to ask about the image.

        Returns:
            dict: "pixel_values_aux" (1, C, H, W) from the image processor,
            "pixel_values" (1, C, 336, 336) downsampled copy, "gt_boxes"
            (1, N, 4) xyxy boxes on the resized image, and "input_ids" (1, L)
            containing sentinel image/object placeholder ids.
        """
        data_dict = {}

        # Load from disk when a path was supplied.
        if type(image) == str:
            image = Image.open(image).convert("RGB")
        # NOTE(review): assumes F.get_image_size returns (width, height) here — confirm.
        ori_w, ori_h = F.get_image_size(image)
        # Pad to a square filled with the processor's mean color so the
        # subsequent resize preserves aspect ratio.
        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.image_processor.image_mean),
        )
        pad_w, pad_h = F.get_image_size(image)
        image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        resize_h, resize_w = image_aux.shape[-2:]
        data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)
        # Low-resolution 336x336 copy of the already-preprocessed tensor.
        image = image_aux.clone()
        image = torch.nn.functional.interpolate(
            image[None],
            size=[336, 336],
            mode="bilinear",
            align_corners=False,
        )[0]
        data_dict["pixel_values"] = image.unsqueeze(0)

        # Apply the same pad + resize transform to the boxes (in xywh space),
        # then convert back to xyxy for the model.
        bbox= xyxy_to_xywh(bbox)
        bbox = pad_boxes(bbox, (ori_w, ori_h))
        bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w))
        data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0)

        # Build the prompt: "<image>\n<obj0><objfeat>...<objN-1><objfeat>\n<question>".
        total_num_boxes = len(bbox)
        obj_tokens = [
            DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) for i in range(total_num_boxes)
        ]
        obj_tokens = (
            DEFAULT_OBJECT_FEATURE_TOKEN.join(obj_tokens) + DEFAULT_OBJECT_FEATURE_TOKEN
        )
        # Strip any caller-provided <image> token; it is re-inserted at the front.
        question = question.replace(DEFAULT_IMAGE_TOKEN, "")
        question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question

        inputs = ""
        inputs += self.template["INSTRUCTION"].format(input=question, round=1)

        # Encode, replacing placeholder strings with sentinel ids.
        input_ids = tokenizer_image_object_token(inputs, self.tokenizer)
        data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)

        return data_dict
|
|
# Register so AutoProcessor.from_pretrained can resolve this custom processor class.
ChatRexProcessor.register_for_auto_class()