from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


DESCRIPTION = """\
# Prot2Text Demo

A demo that generates a description of a protein's function from its amino acid sequence and structure, using [Prot2Text Base v1.1](https://huggingface.co/habdine/Prot2Text-Base-v1-1). To try the model, simply enter the protein's AlphaFoldDB ID below.
"""


MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 100


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


tokenizer = AutoTokenizer.from_pretrained(
    "habdine/Prot2Text-Base-v1-1",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "habdine/Prot2Text-Base-v1-1",
    trust_remote_code=True,
).to(device)
model.eval()
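

# Hedged sketch: a one-shot (blocking) call for quick local testing. The kwargs
# mirror those passed to `generate_protein_description` in `generate` below;
# that the method *returns* the decoded description when no `streamer` is given
# is an assumption, not something this file confirms. `describe_once` is a
# hypothetical helper and is not wired into the demo.
def describe_once(alphafold_id: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> str:
    return model.generate_protein_description(
        protein_pdbID=alphafold_id,  # e.g. "P0A0V1"
        tokenizer=tokenizer,
        device=device,
        max_new_tokens=max_new_tokens,
        num_beams=1,  # greedy decoding, matching the streaming path below
    )
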
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],  # required by gr.ChatInterface; unused in the body
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Cap the request at the advertised maximum number of new tokens.
    max_new_tokens = min(max_new_tokens, MAX_MAX_NEW_TOKENS)

    # Stream decoded tokens as they arrive; skip the prompt and special tokens.
    # (TextIteratorStreamer raises queue.Empty if no token arrives within 20 s.)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        protein_pdbID=message,
        tokenizer=tokenizer,
        device=device,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so partial text can be yielded below.
    t = Thread(target=model.generate_protein_description, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
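

# Hedged sketch: AlphaFoldDB entries are keyed by UniProt accessions, so the
# user's input could be pre-checked before a generation thread is spawned.
# The regex follows UniProt's documented accession format;
# `looks_like_uniprot_accession` is a hypothetical helper and is not wired
# into `generate` here.
def looks_like_uniprot_accession(candidate: str) -> bool:
    import re  # local import to keep this sketch self-contained

    pattern = r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
    return re.fullmatch(pattern, candidate.strip().upper()) is not None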


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Checkbox(label="Do Sample"),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["P0A0V1"],
        ["Q10MK9"],
        ["Q6K5W5"],
        ["Q65WY8"],
    ],
    cache_examples=False,
    type="messages",
)


with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)