|
|
| import os |
| import io |
| import jax |
| import base64 |
| import warnings |
| import functools |
| import numpy as np |
| import sentencepiece |
| import ml_collections |
| from PIL import Image |
| import big_vision.utils |
| import tensorflow as tf |
| import supervision as sv |
| import big_vision.sharding |
| from typing import Tuple, List, Optional |
| from big_vision.models.proj.paligemma import paligemma |
| from big_vision.trainers.proj.paligemma import predict_fns |
|
|
|
|
|
|
# Token sequence length: data_iterator pads/truncates prompts to this length,
# and make_predictions uses it as the default max_decode_len.
SEQLEN: int = 128
|
|
class PaliGemmaManager:
    """Singleton wrapper around a PaliGemma model for detection-style inference.

    Every construction returns the same instance (see ``__new__``). Note that
    Python still runs ``__init__`` on each construction, so the model, params
    and tokenizer are rebound to the most recent constructor arguments.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Reuse the single instance; the constructor arguments are consumed by
        # __init__, not by object.__new__.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model, params, tokenizer):
        """Store the model components and build decode function and shardings.

        Args:
            model: Model object handed to ``predict_fns.get_all``.
            params: Parameter pytree passed to the decode function.
            tokenizer: Tokenizer exposing ``encode``, ``decode`` and ``eos_id``.
        """
        self.model = model
        self.params = params
        self.tokenizer = tokenizer
        # The attributes below are populated by initialise_model().
        self.decode_fn = None
        self.decode = None
        self.mesh = None
        self.data_sharding = None
        self.params_sharding = None
        self.trainable_mask = None

        self.initialise_model()

    def initialise_model(self):
        """Build the decode function, trainable-param mask and device shardings."""
        self.decode_fn = predict_fns.get_all(self.model)["decode"]
        self.decode = functools.partial(
            self.decode_fn,
            devices=jax.devices(),
            eos_token=self.tokenizer.eos_id(),
        )

        def is_trainable_param(name, param):
            # Only llm/layers/attn/* is marked trainable; other llm/* and img/*
            # params are not, and any other top-level name is unexpected.
            if name.startswith("llm/layers/attn/"):
                return True
            if name.startswith("llm/"):
                return False
            if name.startswith("img/"):
                return False
            raise ValueError(f"Unexpected param name {name}")

        # NOTE: trainable_mask and params_sharding are not used by the
        # inference methods of this class.
        self.trainable_mask = big_vision.utils.tree_map_with_names(
            is_trainable_param, self.params)

        # One-axis mesh over all devices. ("data",) is the intended 1-tuple of
        # axis names; ("data") would just be the string "data".
        self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))
        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))
        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[(".*", 'fsdp(axis="data")')], mesh=self.mesh)

    def preprocess_image(self, image, size=224):
        """Resize an image to ``(size, size, 3)`` and rescale pixel values.

        Args:
            image: PIL image or array-like. PIL images in non-RGB modes
                (palette, LA, CMYK, ...) are converted to RGB first; 2-D arrays
                are replicated to three channels and extra channels (e.g.
                alpha) are dropped.
            size: Output height and width in pixels.

        Returns:
            Array computed as ``resized / 127.5 - 1.0`` (so 0..255 maps to
            -1..1).

        Raises:
            ValueError: if the image does not end up with exactly 3 channels.
        """
        # Palette/LA/CMYK images would otherwise be read as raw indices or the
        # wrong channels by np.asarray.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        image = np.asarray(image)
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]
        if image.shape[-1] != 3:
            raise ValueError(f"Expected a 3-channel image, got shape {image.shape}")

        image = tf.constant(image)
        image = tf.image.resize(image, (size, size), method="bilinear", antialias=True)
        return image.numpy() / 127.5 - 1.0

    def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
        """Tokenize a prompt (and optional target) into model input arrays.

        The prefix is encoded with BOS and followed by a newline separator. The
        optional suffix is encoded with EOS. ``mask_ar`` and ``mask_loss`` are 0
        for prefix tokens and 1 for suffix tokens; ``mask_input`` is 1 for every
        real token.

        Args:
            prefix: Prompt text.
            suffix: Optional target text.
            seqlen: If given, every output is truncated / zero-padded to it.

        Returns:
            Tuple of numpy arrays ``(tokens, mask_ar, mask_loss, mask_input)``.
        """
        separator = "\n"
        tokens = self.tokenizer.encode(prefix, add_bos=True) + self.tokenizer.encode(separator)
        mask_ar = [0] * len(tokens)
        mask_loss = [0] * len(tokens)

        if suffix:
            suffix = self.tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)
            mask_loss += [1] * len(suffix)

        mask_input = [1] * len(tokens)
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding

        return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

    def postprocess_tokens(self, tokens):
        """Decode a token-id array to text, cutting at the first EOS token."""
        tokens = tokens.tolist()
        try:
            eos_pos = tokens.index(self.tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass  # No EOS: decode the full sequence.
        return self.tokenizer.decode(tokens)

    @staticmethod
    def split_and_keep_second_part(s):
        """Return the text after the first newline in *s*, or *s* if there is none."""
        # Was defined without self/@staticmethod, so instance calls failed.
        parts = s.split("\n", 1)
        if len(parts) > 1:
            return parts[1]
        return s

    def data_iterator(self, image_bytes, caption):
        """Yield one model input example from encoded image bytes and a prompt.

        Yields:
            Dict with keys ``image``, ``text``, ``mask_ar`` and ``mask_input``;
            the text is padded/truncated to ``SEQLEN`` tokens.
        """
        with Image.open(io.BytesIO(image_bytes)) as pil_image:
            image = self.preprocess_image(pil_image)
        tokens, mask_ar, _, mask_input = self.preprocess_tokens(caption, seqlen=SEQLEN)

        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }

    def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        """Run batched decoding over the examples of *data_iterator*.

        Args:
            data_iterator: Iterable of example dicts (see ``data_iterator``).
            num_examples: Stop after this many outputs; falsy means consume all.
            batch_size: Examples per decode call; the last batch is padded.
            seqlen: Passed to the decode function as ``max_decode_len``.
            sampler: Sampler name passed to the decode function.

        Returns:
            List of ``(image, response)`` tuples: the preprocessed image array
            and the decoded text for each real (non-padding) example.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Accept any iterable, not only iterators/generators.
        data_iterator = iter(data_iterator)

        outputs = []
        while True:
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)
            except StopIteration:
                if not examples:
                    return outputs

            # Pad a partial final batch with copies of the last example; their
            # "_mask" is False so their outputs are dropped below.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)

            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, self.data_sharding)
            tokens = self.decode({"params": self.params}, batch=batch,
                                 max_decode_len=seqlen, sampler=sampler)

            # Fetch results to host and keep only rows of real examples.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]
            responses = [self.postprocess_tokens(t) for t in tokens]

            # Padding copies sit at the end, so zip stops before them.
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs

    def process_result_to_bbox(self, image, caption, classes, w, h):
        """Extract the first bounding box from a model detection response.

        Args:
            image: Unused; kept for interface compatibility.
            caption: Decoded model response, parsed by
                ``sv.Detections.from_lmm(lmm='paligemma', ...)``.
            classes: Passed to supervision as ``classes``.
            w: Width of the original image (boxes are scaled to it).
            h: Height of the original image.

        Returns:
            ``[x, y, width, height]`` of the first detection, or
            ``[0, 0, 0, 0]`` if nothing was detected or parsing failed.
        """
        try:
            detections = sv.Detections.from_lmm(
                lmm="paligemma",
                result=caption,
                resolution_wh=(w, h),
                classes=classes)  # previously passed `caption`, ignoring `classes`
        except Exception as e:  # best-effort: a bad response must not abort predict()
            warnings.warn(f"Error detection: {e}")
            return [0, 0, 0, 0]

        if len(detections) == 0:
            warnings.warn(f"No detection found in response: {caption!r}")
            return [0, 0, 0, 0]

        x1, y1, x2, y2 = detections.xyxy[0]
        return [x1, y1, x2 - x1, y2 - y1]

    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        """Locate the object described by *caption* in an encoded image.

        Args:
            image: Encoded image bytes (any format PIL can open).
            caption: Object description; ``"detect "`` is prepended unless
                the caption already starts with it.

        Returns:
            ``(bbox, response)``: ``bbox`` is ``[x, y, width, height]`` in
            original-image pixels (``[0, 0, 0, 0]`` if nothing was found) and
            ``response`` is the decoded model output.

        Raises:
            RuntimeError: if the model produced no prediction.
        """
        with Image.open(io.BytesIO(image)) as image_original:
            original_width, original_height = image_original.size
        # Prefix check rather than substring: "detected ..." still needs it.
        if not caption.startswith("detect "):
            caption = f"detect {caption}"

        predictions = self.make_predictions(
            self.data_iterator(image, caption), num_examples=1)
        if not predictions:
            raise RuntimeError("Model produced no prediction")

        processed_image, response = predictions[0]
        classes = response.replace("detect ", "")
        output = self.process_result_to_bbox(
            processed_image, response, classes, original_width, original_height)
        return output, response
|
|
|
|
|
|
|
|
|
|
|
|
# Image and prompt for the demo inference below.
INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."

# Checkpoint locations. They can be overridden through environment variables so
# the script is not tied to one machine's directory layout; the defaults are
# the original hard-coded paths.
TOKENIZER_PATH = os.environ.get(
    'PALIGEMMA_TOKENIZER_PATH',
    '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model')
MODEL_PATH = os.environ.get(
    'PALIGEMMA_MODEL_PATH',
    '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz')

model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"},
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)

with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()

output, response = paligemma_manager.predict(image_bytes, INFERENCE_PROMPT)

# Reuse the bytes already read and close the image once its size is known.
with Image.open(io.BytesIO(image_bytes)) as image:
    resolution_wh = image.size

detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=resolution_wh,
    classes=response)

# detections.xyxy[0] would raise IndexError when the model returned no box.
if len(detections) == 0:
    print('No detection found in model response:', response)
else:
    coordinates = detections.xyxy[0]
    print('x1,y1,x2,y2:', coordinates)