from threading import Thread
from typing import Iterator
import tempfile

import torch
import gradio as gr
from PIL import Image as PILImage
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, TextIteratorStreamer


def get_gradio_demo(model, tokenizer, image_processor) -> gr.Blocks:

    def get_prompt(message: str, chat_history: list[tuple[str, str]],
                   system_prompt: str) -> str:
        texts = [f'#instruction: {system_prompt}\n', '#context:\n']
        texts += [f"human: {user_input.strip()}\nagent: {response.strip()}\n"
                  for user_input, response in chat_history if isinstance(user_input, str)]
        texts += [f'human: {message.strip()}']
        return ''.join(texts)
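
    # For illustration, with a hypothetical two-turn history the serialized
    # prompt looks like:
    #   #instruction: can you specify which region the context describes?
    #   #context:
    #   human: the man on the left
    #   agent: the one in red or in blue?
    #   human: in red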

    def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
        prompt = get_prompt(message, chat_history, system_prompt)
        input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
        return input_ids.shape[-1]

    def run(image: PILImage.Image,
            message: str,
            chat_history: list[tuple[str, str]],
            system_prompt: str,
            max_new_tokens: int = 192,
            temperature: float = 0.1,
            top_p: float = 0.9,
            top_k: int = 50) -> Iterator[str]:
        prompt = get_prompt(message, chat_history, system_prompt)
        patch_images = image_processor([image], return_tensors="pt").pixel_values.to(torch.float16).to('cuda')
        inputs = tokenizer([prompt], return_tensors='pt').to('cuda')

        streamer = TextIteratorStreamer(tokenizer, timeout=10.0)
        generate_kwargs = dict(
            inputs,
            patch_images=patch_images,
            streamer=streamer,
            max_length=max_new_tokens,  # the UI's "max new tokens" value is passed as generate()'s max_length
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_beams=1,
        )
        # Run generation in a background thread so tokens can be streamed.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield ''.join(outputs).replace("not yet.", "").replace("<s>", "").replace("</s>", "").strip()
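
    # Minimal sketch of driving `run` directly (hypothetical image path, for
    # illustration only; the Gradio event handlers below are the real entry point):
    #   for partial in run(PILImage.open("example.jpg"), "where is the dog?", [],
    #                      DEFAULT_SYSTEM_PROMPT):
    #       print(partial)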

    DEFAULT_SYSTEM_PROMPT = """can you specify which region the context describes?"""
    MAX_MAX_NEW_TOKENS = 512
    DEFAULT_MAX_NEW_TOKENS = 128
    MAX_INPUT_TOKEN_LENGTH = 512

    DESCRIPTION = """<h1 align="center">TiO Demo</h1>
<div align="center">https://huggingface.co/jxu124/TiO</div>
"""

    LICENSE = """
<p/>

---
"""

    if not torch.cuda.is_available():
        DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'

    def upload_image(file_obj):
        chatbot = [[(file_obj.name,), None]]
        return (gr.update(visible=False),
                gr.update(interactive=True, placeholder='Type a message...'),
                chatbot)
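
    # In Gradio's tuple-format Chatbot, a (filepath,) tuple renders as an
    # image/file message; this is how the uploaded picture enters the history.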

    def clear_and_save_textbox(message: str) -> tuple[str, str]:
        return '', message

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        if len(history) == 0:
            raise gr.Error('Upload an image first and try again.')
        history.append((message, ''))
        return history

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        try:
            message, _ = history.pop()
            # If the popped turn was an image (a tuple), pop again to reach
            # the last text message.
            if not isinstance(message, str):
                message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def generate(
        message: str,
        history_with_input: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Iterator[list[tuple[str, str]]]:
        if max_new_tokens > MAX_MAX_NEW_TOKENS:
            raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}.')

        # The first history entry is the uploaded image: [(filepath,), None].
        image = PILImage.open(history_with_input[0][0][0])
        history = history_with_input[:-1]
        generator = run(image, message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
        try:
            first_response = next(generator)
            yield history + [(message, first_response)]
        except StopIteration:
            yield history + [(message, '')]
        for response in generator:
            chatbot = history + [(message, response)]
            if "region:" in response:
                bboxes = model.utils.sbbox_to_bbox(response)
                if len(bboxes):
                    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
                        model.utils.show_mask(image, bboxes).save(f)
                    chatbot = history + [(message, "OK, I see."), (None, (f.name,))]
            yield chatbot
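
    # model.utils.sbbox_to_bbox / model.utils.show_mask come from the remote
    # code shipped with jxu124/TiO (trust_remote_code=True): a reply containing
    # "region:" carries a serialized box, which is parsed, drawn onto the
    # image, and shown in the chat as "OK, I see." plus the annotated picture.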

    def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
        # Currently unwired helper (e.g. for gr.Examples); note that `generate`
        # reads the uploaded image from the first history entry, so an empty
        # history would fail here.
        x = []
        generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 192, 1, 0.95, 50)
        for x in generator:
            pass
        return '', x

    def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
        input_token_length = get_input_token_length(message, chat_history[:-1], system_prompt)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Group():
            chatbot = gr.Chatbot(label='Chatbot')
            imagebox = gr.File(
                file_types=["image"],
                show_label=False,
            )
            with gr.Row():
                textbox = gr.Textbox(
                    container=False,
                    show_label=False,
                    interactive=False,
                    placeholder='Upload an image...',
                    scale=10,
                )
                submit_button = gr.Button('Submit',
                                          variant='primary',
                                          scale=1,
                                          min_width=0)
            with gr.Row():
                retry_button = gr.Button('🔄 Retry', variant='secondary')
                undo_button = gr.Button('↩️ Undo', variant='secondary')
                clear_button = gr.Button('🗑️ Clear', variant='secondary')

        saved_input = gr.State()

        with gr.Accordion(label='Advanced options', open=False):
            system_prompt = gr.Textbox(label='System prompt',
                                       value=DEFAULT_SYSTEM_PROMPT,
                                       lines=6)
            max_new_tokens = gr.Slider(
                label='Max new tokens',
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature',
                minimum=0.1,
                maximum=4.0,
                step=0.1,
                value=0.5,
            )
            top_p = gr.Slider(
                label='Top-p (nucleus sampling)',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=0.9,
            )
            top_k = gr.Slider(
                label='Top-k',
                minimum=1,
                maximum=1000,
                step=1,
                value=20,
            )

        gr.Markdown(LICENSE)

        imagebox.upload(
            fn=upload_image,
            inputs=imagebox,
            outputs=[imagebox, textbox, chatbot],
            api_name=None,
            queue=False,
        )

        textbox.submit(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name="generate",
        )

        button_event_preprocess = submit_button.click(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=None,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )

        retry_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=None,
            queue=False,
        ).then(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=None,
        )

        undo_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=None,
            queue=False,
        ).then(
            fn=lambda x: x,
            inputs=[saved_input],
            outputs=textbox,
            api_name=None,
            queue=False,
        )

        clear_button.click(
            fn=lambda: ([], '', gr.update(value=None, visible=True),
                        gr.update(interactive=False, placeholder='Upload an image...')),
            outputs=[chatbot, saved_input, imagebox, textbox],
            queue=False,
            api_name=None,
        )

    return demo


def main(model_id: str = 'jxu124/TiO', host: str = "0.0.0.0", port: int | None = None):
    assert torch.cuda.is_available(), "This demo requires a CUDA-capable GPU."
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    image_processor = AutoImageProcessor.from_pretrained(model_id)

    demo = get_gradio_demo(model, tokenizer, image_processor)
    demo.queue(max_size=20).launch(server_name=host, server_port=port)
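
# Example invocations via python-fire (the script filename is assumed):
#   python tio_demo.py
#   python tio_demo.py --model_id jxu124/TiO --host 0.0.0.0 --port 7860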


if __name__ == "__main__":
    import fire
    fire.Fire(main)