| import torch |
| from PIL import Image |
| import os |
| import numpy as np |
|
|
| from transformers import AutoModel, AutoTokenizer, AutoImageProcessor |
|
|
| from types import MethodType |
| from detectron2.data import MetadataCatalog |
| from detectron2.utils.visualizer import ColorMode, Visualizer |
|
|
| from transformers import AutoModel, AutoTokenizer, AutoImageProcessor |
| from detectron2.data.detection_utils import read_image |
| from detectron2.utils.visualizer import GenericMask |
| import matplotlib.colors as mplc |
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw instance-level segmentation masks with text labels on an image.

    Intended to be bound onto a detectron2 ``Visualizer`` instance via
    ``MethodType``, replacing ``draw_instance_predictions``.

    Args:
        labels (list[str]): one text label per mask.
        np_masks: per-instance binary masks; each entry must be convertible
            by ``GenericMask`` at the output image's (height, width).
        jittering: if True, in color mode SEGMENTATION, randomly jitter the
            colors per class to distinguish instances from the same class.

    Returns:
        output (VisImage): image object with visualizations.
    """
    # This simplified entry point only receives masks + labels; boxes,
    # per-instance class ids, and keypoints are not available here.
    boxes = None
    classes = None
    keypoints = None

    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]

    # Per-class colors require class ids. ``classes`` is always None in this
    # cached variant, so without the explicit guard the original code crashed
    # with ``TypeError: 'NoneType' object is not iterable`` whenever the
    # visualizer ran in SEGMENTATION mode with thing_colors set. Fall back to
    # auto-assigned colors in that case instead.
    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        colors = (
            [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
            if jittering
            else [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        )
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
|
|
|
|
def visualize(image_path, cat_masks, out_path, tags):
    """
    Overlay per-instance masks (with text tags) on an image and save it.

    Args:
        image_path (str): path of the input image on disk.
        cat_masks: stacked per-instance binary masks (first axis = instance).
        out_path (str): destination path for the rendered visualization.
        tags (list[str] or None): one label per mask; when None, masks are
            labeled with their index.
    """
    # Fall back to numeric index labels when no tags were provided.
    left_tags = [f'{i}' for i in range(len(cat_masks))] if tags is None else tags

    # Register a throwaway metadata entry whose class lists cover the tags.
    unique_tags = list(set(left_tags))
    text_prompt = ','.join(unique_tags)
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags

    result_masks = cat_masks
    input_image = read_image(image_path, format="BGR")
    # Visualizer expects RGB, so reverse the BGR channel order.
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    # Bind the mask/label-based drawing routine onto this visualizer instance.
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)

    rendered = Image.fromarray(vis_output.get_image())
    rendered.save(out_path)
|
|
# ---------------------------------------------------------------------------
# Model / tokenizer setup
# ---------------------------------------------------------------------------
# Load the project-local panoptic VLM checkpoint (HF format, custom code).
path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,  # bf16 inference
    low_cpu_mem_usage=True,
    use_flash_attn=True,  # NOTE(review): presumably requires flash-attn installed — confirm
    trust_remote_code=True).eval().cuda()


tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)


# Input image; (width, height) is used later as the panoptic target size.
image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size


# Build a "<p>name</p> [CLS], ..." prompt covering every COCO category.
# NOTE(review): coco_category_names is built but never used below — the
# question string is hard-coded to six classes. Dead code or work-in-progress?
from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
coco_category_names = ""
for item in COCO_CATEGORIES:
    class_name = item['name']
    coco_category_names += f"<p>{class_name}</p> [CLS], "
coco_category_names = coco_category_names[:-2]  # strip the trailing ", "

question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."


# Mask2Former image processor feeding the model's segmentation head
# (local path mirroring the facebook/mask2former-swin-large-coco-panoptic repo).
m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True,)


# Joint chat + segmentation forward pass (custom predict_forward from the
# checkpoint's remote code).
chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
answer = chat_outputs['prediction']
masks = chat_outputs['prediction_masks']  # NOTE(review): unused below — verify


m2f_outputs = chat_outputs['m2f_outputs']


# Mapping from Mask2Former label ids to human-readable class text.
label_id_to_text = m2f_outputs['label_id_to_text']


# Post-process query logits into a panoptic segmentation resized back to the
# original image resolution; target size is (height, width), i.e. rows x cols.
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs['class_queries_logits'],
    m2f_outputs['masks_queries_logits'],
    target_sizes=[(height, width)],
)


print(f"user: {question}")
print(f"assistant: {answer}")


# Split the panoptic id map into one boolean mask per segment plus its tag.
segmentation = post_m2f_outputs[0]['segmentation']
segments_info = post_m2f_outputs[0]['segments_info']
pano_masks, pano_tags = [], []
for item in segments_info:
    mask = segmentation == item['id']
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())  # (1, H, W) bool
    pano_tags.append(label_id_to_text[item['label_id']])


pano_masks = np.concatenate(pano_masks, axis=0)  # (num_segments, H, W)


visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)
| |
|
|
|
|