import base64
import logging
import math
from io import BytesIO
from typing import List, Optional, TypeAlias, Union

import requests
import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor
|
|
# An image source may be a PIL image or a path / URL / data-URI string understood by fetch_image.
ImageSource: TypeAlias = Union[str, Image.Image]
ImageInput: TypeAlias = Union[ImageSource, List[ImageSource]]
BatchImageInput: TypeAlias = Union[List[ImageSource], List[List[ImageSource]]]
|
|
|
|
| class OpsMMEmbeddingV1(nn.Module): |
| def __init__( |
| self, |
| model_name: str, |
| device: str = "cuda", |
| max_length: Optional[int] = None, |
| attn_implementation: Optional[str] = None, |
| ): |
| super().__init__() |
| self.device = device |
| self.max_length = max_length |
| self.default_instruction = "You are a helpful assistant." |
| self.base_model = AutoModelForImageTextToText.from_pretrained( |
| model_name, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| attn_implementation=attn_implementation, |
| ).to(self.device) |
|
|
| self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28) |
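        # Left padding keeps each sequence's final real token at position -1, which _pooling relies on.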
| self.processor.tokenizer.padding_side = "left" |
| self.eval() |
|
|
    def encode_input(self, inputs: dict) -> torch.Tensor:
        """Run the backbone with hidden states exposed and pool the last layer into normalized embeddings."""
        outputs = self.base_model(**inputs, return_dict=True, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        pooled_output = self._pooling(last_hidden_state)
        return pooled_output
|
|
    def _pooling(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        """Last-token pooling: left padding guarantees position -1 holds the final real token."""
        reps = last_hidden_state[:, -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps
|
|
| def _validate_instructions( |
| self, |
| texts: Optional[List[str]], |
| images: Optional[BatchImageInput], |
| instruction: Optional[Union[str, List[str]]], |
| ) -> List[str]: |
| """Validate and format instructions to match batch size""" |
| batch_size = max(len(x) if x is not None else 0 for x in [texts, images]) |
|
|
| if instruction is None: |
| return [self.default_instruction] * batch_size |
|
|
| if isinstance(instruction, str): |
| return [instruction] * batch_size |
|
|
| if isinstance(instruction, list): |
| if len(instruction) != batch_size: |
| raise ValueError(f"Length of instruction list ({len(instruction)}) must match batch size ({batch_size}) when texts/images are provided") |
| return instruction |
|
|
| raise TypeError("instruction must be str, List[str] or None") |
|
|
    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Convert a single image source or a list of sources into resized PIL images."""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]
|
|
| def embed( |
| self, |
| texts: Optional[List[str]] = None, |
| images: Optional[BatchImageInput] = None, |
| instruction: Optional[Union[str, List[str]]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| """Generate embeddings for text, images, or combined inputs. |
| |
| Args: |
| texts: List of text inputs (optional) |
            images: Can be:
                - List[ImageSource]: Single image per input
                - List[List[ImageSource]]: Multiple images per input
                  (a source is a PIL image or a path/URL/data-URI string)
| instruction: Instruction(s) for the model. Can be: |
| - None: use default instruction |
| - str: use same instruction for all inputs |
| - List[str]: per-input instructions (must match batch size) |
| """ |
| if texts is None and images is None: |
| raise ValueError("Either texts or images must be provided") |
|
|
| instructions = self._validate_instructions(texts, images, instruction) |
|
|
| |
| batch_size = len(texts) if texts is not None else len(images) |
|
|
| input_texts, input_images = [], [] |
| for i in range(batch_size): |
| text = texts[i] if texts is not None else None |
| image = images[i] if images is not None else None |
|
|
| input_str = "" |
| processed_image = None |
| if image is not None: |
| processed_image = self._process_images(image) |
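                # One vision placeholder per image; the processor expands <|image_pad|> into patch tokens.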
| input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image) |
|
|
| if text is not None: |
| input_str += text |
|
|
| msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>" |
|
|
| input_texts.append(msg) |
| input_images.append(processed_image) |
|
|
| |
        # Pass images only when at least one row has them; processors generally cannot mix image and image-free rows in one batch.
        processed_images = input_images if any(img is not None for img in input_images) else None
|
|
| inputs = self.processor( |
| text=input_texts, |
| images=processed_images, |
| padding=True, |
| truncation=True, |
| max_length=self.max_length, |
| return_tensors="pt", |
| ) |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} |
|
|
| with torch.inference_mode(): |
| embeddings = self.encode_input(inputs) |
|
|
| return embeddings |
|
|
| def get_text_embeddings( |
| self, |
| texts: List[str], |
| instruction: Optional[Union[str, List[str]]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| """Convenience method for text-only embeddings""" |
| return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs) |
|
|
| def get_image_embeddings( |
| self, |
| images: BatchImageInput, |
| instruction: Optional[Union[str, List[str]]] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| """Convenience method for image-only embeddings. |
| |
| Args: |
            images: Can be:
                - List[ImageSource]: Single image per input
                - List[List[ImageSource]]: Multiple images per input
| """ |
| return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs) |
|
|
| def get_fused_embeddings( |
| self, |
| texts: Optional[List[str]] = None, |
| images: Optional[BatchImageInput] = None, |
| instruction: Optional[Union[str, List[str]]] = None, |
| batch_size: int = 8, |
| show_progress: bool = True, |
| **kwargs, |
| ) -> torch.Tensor: |
| """Batch processing for large collections of texts/images. |
| |
| Args: |
| texts: List of text inputs (optional) |
            images: Can be:
                - List[ImageSource]: Single image per input
                - List[List[ImageSource]]: Multiple images per input
| instruction: Instruction(s) for the model |
| batch_size: Number of items to process at once |
| show_progress: Whether to display progress bar |
| """ |
|
|
| if texts is None and images is None: |
| raise ValueError("Either texts or images must be provided") |
|
|
| total_items = len(texts) if texts is not None else len(images) |
| num_batches = math.ceil(total_items / batch_size) |
|
|
| all_embeddings = [] |
| progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing") |
|
|
        for i in range(0, total_items, batch_size):
            batch_texts = texts[i : i + batch_size] if texts is not None else None
            batch_images = images[i : i + batch_size] if images is not None else None
            # Slice per-input instruction lists so each chunk matches its own batch size.
            batch_instruction = instruction[i : i + batch_size] if isinstance(instruction, list) else instruction
            batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=batch_instruction)
|
|
            all_embeddings.append(batch_emb.cpu())  # collect on CPU to cap GPU memory
| progress.update(1) |
|
|
| progress.close() |
| return torch.cat(all_embeddings, dim=0).to(self.device) |
|
|
| def forward(self, **inputs) -> torch.Tensor: |
| """Alias for encode_input""" |
| return self.encode_input(inputs) |
|
|
|
|
| |
|
|
| IMAGE_FACTOR = 28 |
| MIN_PIXELS = 256 * 28 * 28 |
| MAX_PIXELS = 1280 * 28 * 28 |
| MAX_RATIO = 200 |
|
|
|
|
| def round_by_factor(number: int, factor: int) -> int: |
| """Returns the closest integer to 'number' that is divisible by 'factor'.""" |
| return round(number / factor) * factor |
|
|
|
|
| def ceil_by_factor(number: int | float, factor: int) -> int: |
| """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" |
| return math.ceil(number / factor) * factor |
|
|
|
|
| def floor_by_factor(number: int | float, factor: int) -> int: |
| """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" |
| return math.floor(number / factor) * factor |
|
|
|
|
| def smart_resize( |
| height: int, |
| width: int, |
| factor: int = IMAGE_FACTOR, |
| min_pixels: int = MIN_PIXELS, |
| max_pixels: int = MAX_PIXELS, |
| ) -> tuple[int, int]: |
| """ |
| Rescales the image so that the following conditions are met: |
| 1. Both dimensions (height and width) are divisible by 'factor'. |
| 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. |
| 3. The aspect ratio of the image is maintained as closely as possible. |
| """ |
| h_bar = max(factor, round_by_factor(height, factor)) |
| w_bar = max(factor, round_by_factor(width, factor)) |
| if h_bar * w_bar > max_pixels: |
| beta = math.sqrt((height * width) / max_pixels) |
| h_bar = floor_by_factor(height / beta, factor) |
| w_bar = floor_by_factor(width / beta, factor) |
| elif h_bar * w_bar < min_pixels: |
| beta = math.sqrt(min_pixels / (height * width)) |
| h_bar = ceil_by_factor(height * beta, factor) |
| w_bar = ceil_by_factor(width * beta, factor) |
|
|
    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}; clamping the longer side")
| if h_bar > w_bar: |
| h_bar = w_bar * MAX_RATIO |
| else: |
| w_bar = h_bar * MAX_RATIO |
| return h_bar, w_bar |
|
|
|
|
def fetch_image(
    image: str | Image.Image,
    size_factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Image.Image:
    """Load an image from a PIL object, http(s) URL, file:// URI, data URI, or local path, then smart-resize it."""
    image_obj = None
| if isinstance(image, Image.Image): |
| image_obj = image |
| elif image.startswith("http://") or image.startswith("https://"): |
| image_obj = Image.open(requests.get(image, stream=True).raw) |
| elif image.startswith("file://"): |
| image_obj = Image.open(image[7:]) |
| elif image.startswith("data:image"): |
| if "base64," in image: |
| _, base64_data = image.split("base64,", 1) |
| data = base64.b64decode(base64_data) |
| image_obj = Image.open(BytesIO(data)) |
| else: |
| image_obj = Image.open(image) |
| if image_obj is None: |
| raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") |
| image = image_obj.convert("RGB") |
| width, height = image.size |
| resized_height, resized_width = smart_resize( |
| height, |
| width, |
| factor=size_factor, |
| min_pixels=min_pixels, |
| max_pixels=max_pixels, |
| ) |
| image = image.resize((resized_width, resized_height)) |
|
|
| return image |
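

# ---------------------------------------------------------------------------
# Usage sketch. Illustrative only: the checkpoint path below is a placeholder,
# not a confirmed model id; substitute a Qwen2-VL-style checkpoint whose chat
# template matches the one hard-coded in `embed`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = OpsMMEmbeddingV1("path/to/mm-embedding-checkpoint", device="cuda")

    # Text-side embeddings with a shared instruction.
    text_embs = model.get_text_embeddings(
        ["a photo of a cat", "a diagram of a transformer"],
        instruction="Represent this query for multimodal retrieval.",
    )

    # Image-side embeddings; sources may be local paths, URLs, or PIL images.
    image_embs = model.get_image_embeddings(["path/to/cat.jpg", "path/to/diagram.png"])

    # Embeddings are L2-normalized, so cosine similarity is a plain matrix product.
    print((text_embs @ image_embs.T).tolist())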
|
|
|
|
| |
|
|