import gradio as gr
import transformers
import torch
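
# A small Gradio chat app: it loads the "joermd/speedy-llama2" model into a
# transformers text-generation pipeline and serves it through gr.ChatInterface.
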

def initialize_pipeline():
    """Load the tokenizer and model and wrap them in a text-generation pipeline."""
    model_id = "joermd/speedy-llama2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_fast=False,
    )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # The model is already dispatched across devices by device_map="auto" above,
    # so the pipeline can reuse it directly without its own device_map argument.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    return pipeline, tokenizer


# Load the model once at startup so every request reuses the same pipeline.
pipeline, tokenizer = initialize_pipeline()


def format_chat_prompt(messages, system_message):
    """Convert Gradio (user, assistant) history pairs into the chat-format
    message list that the text-generation pipeline expects."""
    formatted_messages = []
    if system_message:
        formatted_messages.append({"role": "system", "content": system_message})

    for msg in messages:
        if msg[0]:
            formatted_messages.append({"role": "user", "content": msg[0]})
        if msg[1]:
            formatted_messages.append({"role": "assistant", "content": msg[1]})

    return formatted_messages

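
# Illustrative only: for a history like [("Hi", "Hello!")] and a non-empty system
# message, format_chat_prompt returns a standard chat-format list such as:
# [
#     {"role": "system", "content": "You are a helpful assistant"},
#     {"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello!"},
# ]
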
|
| def respond( |
| message: str, |
| history: list[tuple[str, str]], |
| system_message: str, |
| max_tokens: int, |
| temperature: float, |
| top_p: float, |
| ): |
| """Generate response using the pipeline""" |
| messages = format_chat_prompt(history, system_message) |
| messages.append({"role": "user", "content": message}) |
| |
    # Stop generation at EOS and, if the tokenizer defines it, at the
    # Llama-style "<|eot_id|>" end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None,
    ]
    terminators = [t for t in terminators if t is not None]

    outputs = pipeline(
        messages,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=terminators,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
    )

    # When the pipeline is called with a chat-format message list, generated_text
    # holds the full conversation; the assistant's reply is the last message.
    try:
        response = outputs[0]["generated_text"]
        if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
            response = response[-1].get("content", "")
    except (IndexError, KeyError, AttributeError):
        response = "I apologize, but I couldn't generate a proper response."

    yield response

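
# Note: gr.ChatInterface accepts generator callbacks, and respond() yields its
# reply; switching to token-by-token streaming later would mainly mean yielding
# partial text from this function instead of the full response at once.
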


# Wire the respond() callback and its generation controls into a chat UI.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful assistant",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Chat Assistant",
    description="A conversational AI assistant powered by Llama-2",
)


if __name__ == "__main__":
    demo.launch()
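
# When running locally (outside Hugging Face Spaces), demo.launch(share=True)
# can expose a temporary public URL, and server_name/server_port can be passed
# to control the address and port the app binds to.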