| import gradio as gr |
| import spaces |
| import time |
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForVision2Seq |
| from transformers.image_utils import load_image |
| from typing import List |
# Hugging Face Hub checkpoint served by this demo.
_MODEL_ID = "TIGER-Lab/Mantis-8B-Idefics2"

# Load the processor and the model weights (in bfloat16) once, at import time.
processor = AutoProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(_MODEL_ID, torch_dtype=torch.bfloat16)
|
|
@spaces.GPU
def generate_stream(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply to a chat conversation.

    Args:
        text: Unused; the prompt is built entirely from ``history``.
        images: Images passed to the processor alongside the prompt. An empty
            list is treated as "no images" (``None``).
        history: Chat messages in the format accepted by
            ``processor.apply_chat_template``.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``
            (e.g. ``max_new_tokens``). Any ``streamer`` entry is overwritten.

    Yields:
        The decoded output accumulated so far (a growing string).

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised here
        once the stream has been drained.
    """
    # Function-scope imports, as in the original implementation.
    from threading import Thread
    from transformers import TextIteratorStreamer

    model.to("cuda")
    if not images:
        images = None

    prompt = processor.apply_chat_template(history, add_generation_prompt=True)
    print("Prompt: ")
    print(prompt)
    print("Images: ")
    print(images)
    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs["streamer"] = streamer
    inputs.update(kwargs)

    errors = []

    def _run_generate():
        # Runs in the worker thread. If generate() fails, the streamer never
        # receives its end-of-stream signal and the consumer loop below would
        # block forever; record the error and close the stream explicitly.
        try:
            model.generate(**inputs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()
    output = ""
    for _output in streamer:
        output += _output
        yield output
    thread.join()
    if errors:
        # Surface the generation failure to the caller instead of silently
        # returning a truncated (or empty) stream.
        raise errors[0]
|
|
def enable_next_image(uploaded_images, image):
    """Record ``image`` as uploaded and lock the multimodal input box.

    ``uploaded_images`` is extended in place and also returned, together with a
    cleared, non-interactive ``gr.MultimodalTextbox`` update.
    """
    uploaded_images.append(image)
    locked_input = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_input
|
|
def add_message(history, message):
    """Append a submitted multimodal message to the chatbot history.

    Every uploaded file becomes its own row ``[(path,), None]``. The typed text,
    if non-empty, is appended after them as ``[text, None]``. ``history`` is
    mutated in place and returned together with a cleared input-box update.
    """
    for uploaded_path in message["files"] or ():
        history.append([(uploaded_path,), None])
    typed_text = message["text"]
    if typed_text:
        history.append([typed_text, None])
    cleared_input = gr.MultimodalTextbox(value=None)
    return history, cleared_input
|
|
def print_like_dislike(x: gr.LikeData):
    """Print a chatbot like/dislike event as ``index value liked`` to stdout."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
|
|
|
|
def get_chat_images(history):
    """Load, in history order, every image referenced by the chat rows.

    A row whose first element is a tuple is treated as an image upload; the
    tuple's first item is handed to ``load_image``.
    """
    return [
        load_image(row[0][0])
        for row in history
        if isinstance(row[0], tuple)
    ]
|
|
def get_chat_history(history):
    """Convert chatbot rows into chat-template messages plus their images.

    Args:
        history: Chatbot rows ``[user, bot]``. A ``user`` entry that is a string
            is user text in which each ``<image>`` placeholder marks where the
            next uploaded image goes. Non-string entries (image-upload tuples)
            produce no message of their own.

    Returns:
        A ``(messages, images)`` pair. ``messages`` is a list of
        ``{"role": ..., "content": [...]}`` dicts whose content items are
        ``{"type": "text", "text": ...}`` or ``{"type": "image"}``; ``images``
        is ``get_chat_images(history)``.

    Raises:
        gr.Error: if the text turns, taken together, contain more ``<image>``
            placeholders than there are uploaded images.
    """
    images = get_chat_images(history)
    messages = []
    # Number of images already claimed by placeholders in earlier user turns.
    cur_image_idx = 0
    for message in history:
        user_part = message[0]
        if not isinstance(user_part, str):
            # Image rows carry no text; their images were collected above.
            continue

        num_images = user_part.count("<image>")
        print(num_images, cur_image_idx, len(images))
        # Explicit check rather than ``assert`` (which is stripped under
        # ``python -O``); gr.Error matches the validation style used in bot().
        if num_images + cur_image_idx > len(images):
            raise gr.Error(
                "Number of images uploaded is less than the number of <image> "
                "placeholders in the text. Please upload more images."
            )
        # Placeholders are cumulative across turns, so advance the cursor.
        cur_image_idx += num_images

        content = []
        if num_images > 0:
            segments = user_part.split("<image>")
            if segments[0].strip():
                content.append({"type": "text", "text": segments[0].strip()})
            # split() yields exactly num_images segments after the first one;
            # each follows one placeholder.
            for segment in segments[1:]:
                content.append({"type": "image"})
                if segment.strip():
                    content.append({"type": "text", "text": segment.strip()})
        else:
            content.append({"type": "text", "text": user_part})
        messages.append({"role": "user", "content": content})

        if message[1]:
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": message[1]}],
                }
            )
    return messages, images
|
|
|
|
def bot(history):
    """Generate and stream the assistant reply for the latest user turn.

    Gathers the rows after the most recent row that already has a reply. Their
    texts are joined into one prompt and their image uploads are counted. The
    number of ``<image>`` placeholders is then reconciled with that count: if
    they differ, a warning is shown and the adjusted text is written back into
    ``history[-1][0]``. The reply is streamed into ``history[-1][1]``, and the
    whole history is yielded after each update.

    Raises:
        gr.Error: if the pending turn contains no text.
    """
    # Rows after the last answered one, restored to chronological order.
    pending_rows = []
    for row in reversed(history):
        if row[1]:
            break
        pending_rows.append(row)
    pending_rows.reverse()

    text_parts = [row[0] for row in pending_rows if isinstance(row[0], str)]
    image_count = sum(len(row[0]) for row in pending_rows if isinstance(row[0], tuple))
    prompt_text = " ".join(text_parts).strip()

    if not prompt_text:
        raise gr.Error("Please enter a message")

    placeholder_count = prompt_text.count("<image>")
    if placeholder_count < image_count:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        prompt_text = "<image> " * (image_count - placeholder_count) + prompt_text
        history[-1][0] = prompt_text
    elif placeholder_count > image_count:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        # Drop the surplus placeholders, starting from the end of the text.
        prompt_text = "".join(prompt_text.rsplit("<image>", placeholder_count - image_count))
        history[-1][0] = prompt_text

    chat_history, chat_images = get_chat_history(history)

    generation_kwargs = dict(max_new_tokens=4096, num_beams=1, do_sample=False)

    for partial_reply in generate_stream(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = partial_reply
        time.sleep(0.05)
        yield history
|
|
|
|
| |
def build_demo():
    """Build the Gradio Blocks UI for chatting with Mantis.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:

        gr.Markdown(""" # Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.

### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)
""")

        gr.Markdown("""## Chat with Mantis
Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
(The model currently serving is [🤗 TIGER-Lab/Mantis-8B-Idefics2](https://huggingface.co/TIGER-Lab/Mantis-8B-Idefics2))
""")

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)

        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])

        # TODO: sampling controls (temperature, top-p sliders in an "Advanced
        # options" accordion) were sketched here but are not wired into
        # generation yet.

        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # Chain with .success (like the submit path) so the bot only runs when
        # add_message succeeded. api_name values must be unique; a duplicate
        # "bot_response" would be auto-renamed with a warning, so name it
        # explicitly.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response_1"
        )

        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"]
                },
                {
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?",
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"]
                },
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
                },
            ],
            inputs=[chat_input],
        )

        gr.Markdown("""
## Citation
```
@article{jiang2024mantis,
title={MANTIS: Interleaved Multi-Image Instruction Tuning},
author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
journal={arXiv preprint arXiv:2405.01483},
year={2024}
}
```""")
    return demo
| |
|
|
if __name__ == "__main__":
    # Script entry point: construct the Blocks UI, then start serving it.
    demo = build_demo()
    demo.launch()