from __future__ import annotations

import base64
import logging
import math
import os
from io import BytesIO
from typing import List, Optional

import requests
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import AutoConfig, AutoProcessor
from transformers.models.qwen2_vl import Qwen2VLForConditionalGeneration


class RzenEmbed(nn.Module):
    def __init__(
        self,
        model_name: str = "qihoo360/RzenEmbed",
        model_path: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        min_image_tokens=256,
        max_image_tokens=1280,
        min_video_tokens=160,
        max_video_tokens=180,
        max_length=2000,
        attn_implementation="flash_attention_2",
        processor: Optional[AutoProcessor] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        model_name = model_path or model_name

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        config._attn_implementation = attn_implementation
        config.padding_side = "right"
        config.use_cache = False

        self.base = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, config=config,
            torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        )

        self.base.eval()
        self.normalize = True
        self.device = device
        self.base = self.base.to(self.device)
        print(f"model.device: {str(self.base.device)}")

        # Token budgets are converted to pixel budgets (one vision token covers a 28x28 pixel area).
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.max_length = max_length
        if processor is None:
            processor = AutoProcessor.from_pretrained(
                model_name, min_pixels=min_pixels, max_pixels=max_pixels
            )
        self.processor = processor
        self.processor.tokenizer.padding_side = 'right'
        self.default_instruction = 'You are a helpful assistant.'
        self.sep = ' '

        # A separate processor with a smaller pixel budget is used for video frames.
        min_pixels_video = min_video_tokens * 28 * 28
        max_pixels_video = max_video_tokens * 28 * 28
        self.qwen2vl_video_processor = AutoProcessor.from_pretrained(
            model_name, min_pixels=min_pixels_video, max_pixels=max_pixels_video
        )
        self.qwen2vl_video_processor.tokenizer.padding_side = 'right'

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.base.model.embed_tokens(input_ids)
            has_image = (pixel_values is not None) and any(pv is not None for pv in pixel_values)
            if has_image:
                if isinstance(pixel_values, list):
                    # Per-sample numpy arrays are concatenated into a single batch tensor.
                    pixel_values = torch.cat([torch.from_numpy(pv) for pv in pixel_values]).to(input_ids.device)
                    image_grid_thw = torch.cat([torch.from_numpy(thw) for thw in image_grid_thw]).to(input_ids.device)
                pixel_values = pixel_values.type(self.base.visual.get_dtype())
                image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
                # Scatter the vision features into the image placeholder token positions.
                image_mask = input_ids == self.base.config.image_token_id
                inputs_embeds[image_mask] = image_embeds

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.base.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        # Last-token pooling: with left padding every sequence ends at the final position;
        # with right padding the last non-padded token is selected per sample.
        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[
                torch.arange(batch_size, device=outputs.last_hidden_state.device),
                sequence_lengths,
            ]
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def _process_images(self, images):
        """Convert a single image (PIL.Image or path/URL string) or a list of images into a list of resized PIL images."""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]

    def embed(self, texts: list[str], images: list[Image.Image], **kwargs):
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        # A nested list of frames is treated as a video and routed to the video processor.
        is_video = images is not None and any(isinstance(item, list) for item in images)

        batch_size = len(texts) if texts is not None else len(images)

        input_texts, input_images = [], []
        instruction = self.default_instruction
        for i in range(batch_size):
            text = texts[i] if texts is not None else None
            image = images[i] if images is not None else None

            input_str = ""
            processed_image = None
            if image is not None:
                processed_image = self._process_images(image)
                input_images += processed_image
                input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)

            if text is not None:
                input_str += text

            msg = f"<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            input_texts.append(msg)

        if len(input_images) == 0:
            input_images = None

        # Videos use the lower-resolution video processor; everything else uses the image processor.
        processor = self.qwen2vl_video_processor if is_video else self.processor
        inputs = processor(
            text=input_texts,
            images=input_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self.forward(**inputs)
        return embeddings

    def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)

    def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
        return self.get_fused_embeddings(images=images, **kwargs)

    def get_text_embeddings(self, texts: list[str], **kwargs):
        return self.get_fused_embeddings(texts=texts, **kwargs)

    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
        assert texts or images, "Either 'texts' or 'images' must be provided - both cannot be None or empty"
        instruction = kwargs.pop('instruction', None)
        if instruction is not None:
            if texts is not None:
                texts = [instruction + text for text in texts]
            else:
                texts = [instruction] * len(images)

        if isinstance(images, DataLoader):
            image_loader = images
            batch_size = image_loader.batch_size
            image_loader.dataset.transform = None
        else:
            batch_size = kwargs.pop('batch_size', 32)
            if images is None:
                image_loader = None
            else:
                image_loader = DataLoader(
                    images,
                    batch_size=batch_size,
                    shuffle=False,
                    collate_fn=custom_collate_fn,
                    num_workers=min(math.floor(os.cpu_count() / 2), 8),
                )

        if texts is None:
            assert image_loader is not None
            n_batch = len(image_loader)
        else:
            n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
            image_loader = image_loader or [None] * n_batch

        all_embeddings = list()
        none_batch = [None] * batch_size
        show_progress_bar = kwargs.pop('show_progress_bar', False)
        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
        for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
            # Pad the missing modality with Nones so texts and images stay aligned per sample.
            text_batch = none_batch[:len(img_batch)] if texts is None else texts[n: n + batch_size]
            img_batch = none_batch[:len(text_batch)] if img_batch is None else img_batch
            embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
            pbar.update(1)
            all_embeddings.append(embeddings.cpu())
        pbar.close()
        all_embeddings = torch.cat(all_embeddings, dim=0)
        return all_embeddings

def custom_collate_fn(batch):
    """Identity collate: keep the batch as a plain Python list (of paths / PIL images) instead of stacking tensors."""
    return batch


IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
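
# Illustrative values only (not used at runtime), assuming the patch factor of 28 used above:
# round_by_factor(50, 28) -> 56, ceil_by_factor(50, 28) -> 56, floor_by_factor(50, 28) -> 28.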


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(
            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
        )
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar
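
# Worked example (follows from the constants above): smart_resize(1000, 2000) rounds both sides
# to multiples of 28 and returns (1008, 1988); the resulting pixel count already lies inside
# [MIN_PIXELS, MAX_PIXELS], so no further rescaling is applied.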


def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        headers = {'User-Agent': 'My User Agent 1.0'}
        image_obj = Image.open(requests.get(image, headers=headers, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input; supported inputs are a local path, an http(s) URL, a base64 data URI, or a PIL.Image, got {image}")
    image = image_obj.convert("RGB")

    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    image = image.resize((resized_width, resized_height))

    return image


if __name__ == '__main__':
    rzen = RzenEmbed("qihoo360/RzenEmbed")

    queries = [
        "A curious kitten and a gentle puppy share a moment of connection on the grass.",
        "Fresh fridge full of berries yogurt milk and snacks."
    ]
    candidates = [
        "assets/example1.jpg",
        "assets/example2.jpg",
    ]

    query_instruction = "Find me an everyday image that matches the given caption: "
    candidate_instruction = "Represent the given image."

    query_embeds = rzen.get_fused_embeddings(instruction=query_instruction, texts=queries)
    candidate_embeds = rzen.get_fused_embeddings(instruction=candidate_instruction, images=candidates)

    similarity_scores = query_embeds @ candidate_embeds.T
    print(similarity_scores)
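
    # Illustrative follow-up (not part of the original example): embeddings are L2-normalized,
    # so the matrix above holds cosine similarities; pick the best candidate per query.
    best = similarity_scores.argmax(dim=1)
    for query, idx in zip(queries, best.tolist()):
        print(f"{query!r} -> {candidates[idx]}")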