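# Install ODISE at runtime from a private fork; GITHUB_TOKEN must be provided in the
# environment (e.g. as a Space secret) to authenticate the pip install.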
import os

token = os.environ["GITHUB_TOKEN"]
os.system(f"pip install git+https://xvjiarui:{token}@github.com/xvjiarui/ODISE_NV.git")

import itertools
import json
from contextlib import ExitStack

import gradio as gr
import torch
from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
from PIL import Image
from torch.cuda.amp import autocast

from detectron2.config import instantiate
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
from detectron2.evaluation import inference_context
from detectron2.utils.env import seed_all_rng
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import ColorMode, Visualizer, random_color

from odise import model_zoo
from odise.checkpoint import ODISECheckpointer
from odise.config import instantiate_odise
from odise.data import get_openseg_labels
from odise.model_zoo.model_zoo import _ModelZooUrls
from odise.modeling.wrapper import OpenPanopticInference
from odise.utils.file_io import ODISEHandler, PathManager
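# Redirect ODISE checkpoint downloads from the GitHub release to a Hugging Face
# mirror of the same files, for both the PathManager handler and the model zoo.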
GITHUB_RELEASE_PREFIX = "https://github.com/NVlabs/ODISE/releases/download/v1.0.0/"
HF_MIRROR_PREFIX = "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/"
for k in ODISEHandler.URLS:
    ODISEHandler.URLS[k] = ODISEHandler.URLS[k].replace(GITHUB_RELEASE_PREFIX, HF_MIRROR_PREFIX)
PathManager.register_handler(ODISEHandler())
_ModelZooUrls.PREFIX = _ModelZooUrls.PREFIX.replace(GITHUB_RELEASE_PREFIX, HF_MIRROR_PREFIX)

setup_logger()
logger = setup_logger(name="odise")
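# Build per-vocabulary label and color lists. get_openseg_labels returns one list of
# synonym prompts per category (the first entry is the display name); categories are
# split into "thing" and "stuff" using the detectron2 / Mask2Former dataset metadata.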
COCO_THING_CLASSES = [
    label
    for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
    if COCO_CATEGORIES[idx]["isthing"] == 1
]
COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
COCO_STUFF_CLASSES = [
    label
    for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
    if COCO_CATEGORIES[idx]["isthing"] == 0
]
COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]

ADE_THING_CLASSES = [
    label
    for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
    if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
]
ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
ADE_STUFF_CLASSES = [
    label
    for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
    if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
]
ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
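# LVIS has no palette here, so cycle the COCO colors across its 1203 categories.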
LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
LVIS_COLORS = list(
    itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
)

class VisualizationDemo(object):
    def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
        """
        Args:
            model (nn.Module): the model to run inference with.
            metadata (MetadataCatalog): image metadata (class names and colors).
            aug (T.Augmentation): test-time augmentation applied before inference.
            instance_mode (ColorMode): color mode used by the visualizer.
        """
        self.model = model
        self.metadata = metadata
        self.aug = aug
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
    def predict(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in RGB order,
                matching `utils.read_image(..., format="RGB")` used in this demo).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        height, width = original_image.shape[:2]
        aug_input = T.AugInput(original_image, sem_seg=None)
        self.aug(aug_input)
        image = aug_input.image
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

        inputs = {"image": image, "height": height, "width": width}
        logger.info("forwarding")
        with autocast():
            predictions = self.model([inputs])[0]
        logger.info("done")
        return predictions
    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in RGB order, as read by
                `utils.read_image(..., format="RGB")` in `inference` below).
        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predict(image)
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output
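
# Load the pretrained ODISE (label) COCO config; overlap_threshold=0 keeps all
# segments during panoptic fusion, and the RNG is seeded for reproducibility.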
cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)

cfg.model.overlap_threshold = 0
cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
seed_all_rng(42)

dataset_cfg = cfg.dataloader.test
wrapper_cfg = cfg.dataloader.wrapper

aug = instantiate(dataset_cfg.mapper).augmentations
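
# Instantiate the model, cast it to fp16 (the demo assumes a CUDA device for
# half-precision inference), and load the pretrained checkpoint.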
model = instantiate_odise(cfg.model)
model.to(torch.float16)
model.to(cfg.train.device)
ODISECheckpointer(model).load(cfg.train.init_checkpoint)

title = "ODISE"
description = """
<p style='text-align: center'> <a href='https://jerryxu.net/ODISE' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.04803' target='_blank'>Paper</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Code</a> | <a href='https://youtu.be/Su7p5KYmcII' target='_blank'>Video</a></p>

Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
You may click on one of the examples or upload your own image. \n

ODISE can perform open-vocabulary segmentation: you may input extra classes (separated by commas).
The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonyms for the first class
and b1 and b2 for the second. The first word of each group is displayed as the class name.
"""

article = """
<p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Github Repo</a></p>
"""

examples = [
    [
        "demo/examples/coco.jpg",
        "black pickup truck, pickup truck; blue sky, sky",
        ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
    ],
    [
        "demo/examples/ade.jpg",
        "luggage, suitcase, baggage; handbag",
        ["ADE (150 categories)"],
    ],
    [
        "demo/examples/ego4d.jpg",
        "faucet, tap; kitchen paper, paper towels",
        ["COCO (133 categories)"],
    ],
]
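
# Build the class list and detectron2 metadata for the selected vocabularies. Extra
# user-supplied classes are treated as "thing" categories and get random colors;
# contiguous ids put thing classes first, with stuff classes appended after them.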
def build_demo_classes_and_metadata(vocab, label_list):
    extra_classes = []

    if vocab:
        for words in vocab.split(";"):
            extra_classes.append([word.strip() for word in words.split(",")])
    extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]

    demo_thing_classes = extra_classes
    demo_stuff_classes = []
    demo_thing_colors = extra_colors
    demo_stuff_colors = []

    if any("COCO" in label for label in label_list):
        demo_thing_classes += COCO_THING_CLASSES
        demo_stuff_classes += COCO_STUFF_CLASSES
        demo_thing_colors += COCO_THING_COLORS
        demo_stuff_colors += COCO_STUFF_COLORS
    if any("ADE" in label for label in label_list):
        demo_thing_classes += ADE_THING_CLASSES
        demo_stuff_classes += ADE_STUFF_CLASSES
        demo_thing_colors += ADE_THING_COLORS
        demo_stuff_colors += ADE_STUFF_COLORS
    if any("LVIS" in label for label in label_list):
        demo_thing_classes += LVIS_CLASSES
        demo_thing_colors += LVIS_COLORS

    MetadataCatalog.pop("odise_demo_metadata", None)
    demo_metadata = MetadataCatalog.get("odise_demo_metadata")
    demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
    demo_metadata.stuff_classes = [
        *demo_metadata.thing_classes,
        *[c[0] for c in demo_stuff_classes],
    ]
    demo_metadata.thing_colors = demo_thing_colors
    demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
    demo_metadata.stuff_dataset_id_to_contiguous_id = {
        idx: idx for idx in range(len(demo_metadata.stuff_classes))
    }
    demo_metadata.thing_dataset_id_to_contiguous_id = {
        idx: idx for idx in range(len(demo_metadata.thing_classes))
    }

    demo_classes = demo_thing_classes + demo_stuff_classes

    return demo_classes, demo_metadata
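
# Wrap the base model in OpenPanopticInference for the chosen vocabulary and run a
# single image through it; only the panoptic head is enabled for visualization.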
def inference(image_path, vocab, label_list):
    logger.info("building class names")
    demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
    with ExitStack() as stack:
        inference_model = OpenPanopticInference(
            model=model,
            labels=demo_classes,
            metadata=demo_metadata,
            semantic_on=False,
            instance_on=False,
            panoptic_on=True,
        )
        stack.enter_context(inference_context(inference_model))
        stack.enter_context(torch.no_grad())

        demo = VisualizationDemo(inference_model, demo_metadata, aug)
        img = utils.read_image(image_path, format="RGB")
        _, visualized_output = demo.run_on_image(img)
        return Image.fromarray(visualized_output.get_image())
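
# Gradio UI. Note: this file targets the gradio 3.x API (gr.inputs / gr.outputs and
# .style() are removed in gradio 4), so pin an appropriate gradio version.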
with gr.Blocks(title=title) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
    gr.Markdown(description)
    input_components = []
    output_components = []

    with gr.Row():
        output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
        output_components.append(output_image_gr)

    with gr.Row().style(equal_height=True, mobile_collapse=True):
        with gr.Column(scale=3, variant="panel") as input_component_column:
            input_image_gr = gr.inputs.Image(type="filepath")
            extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
            category_list_gr = gr.inputs.CheckboxGroup(
                choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
                default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
                label="Category to use",
            )
            input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])

        with gr.Column(scale=2):
            examples_handler = gr.Examples(
                examples=examples,
                inputs=[c for c in input_components if not isinstance(c, gr.State)],
                outputs=[c for c in output_components if not isinstance(c, gr.State)],
                fn=inference,
                cache_examples=torch.cuda.is_available(),
                examples_per_page=5,
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

    gr.Markdown(article)
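
    # Submit runs inference and also exposes the endpoint under api_name="predict".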
    submit_btn.click(
        inference,
        input_components,
        output_components,
        api_name="predict",
        scroll_to_output=True,
    )
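
    # Clear resets everything client-side via _js (no Python roundtrip): one cleared
    # value per output component, plus an update keeping the input column visible.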
    clear_btn.click(
        None,
        [],
        (input_components + output_components + [input_component_column]),
        _js=f"""() => {json.dumps(
            [component.cleared_value if hasattr(component, "cleared_value") else None
             for component in input_components + output_components]
            + [gr.Column.update(visible=True)]
        )}
        """,
    )

demo.launch()