import logging
from threading import Thread

import torch
from nltk import sent_tokenize
from rich.console import Console
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

from baseHandler import BaseHandler
from LLM.chat import Chat

logger = logging.getLogger(__name__)

console = Console()

# Maps Whisper language codes to the language names used in the reply prompt.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}


class LanguageModelHandler(BaseHandler):
    """
    Handles the language-model part of the pipeline: wraps a Hugging Face
    text-generation pipeline and streams replies back sentence by sentence.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        # Resolve the dtype name (e.g. "float16") to the torch dtype object.
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
        # The streamer yields decoded text as it is generated, so the handler
        # can consume tokens while generation runs in a background thread.
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

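        # gen_kwargs is forwarded to every self.pipe(...) call; standard
        # transformers generate() parameters such as
        # {"max_new_tokens": 128, "temperature": 0.7, "do_sample": True}
        # are typical (shown purely as an example).
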
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

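    # Warmup issues a couple of throwaway generations so that weights are
    # resident and allocator caches are primed before the first real request.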
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        # Use .get so warmup does not raise a KeyError when min/max_new_tokens
        # were not supplied in gen_kwargs.
        warmup_gen_kwargs = {
            "min_new_tokens": self.gen_kwargs.get("min_new_tokens"),
            "max_new_tokens": self.gen_kwargs.get("max_new_tokens"),
            **self.gen_kwargs,
        }

        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            # Draining the streamer blocks until this generation has finished.
            for _ in self.streamer:
                pass

        # The timing events only exist on CUDA, so only report timing there.
        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

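    # process() accepts either a plain prompt string or a (prompt, language_code)
    # tuple (as produced by a Whisper STT stage) and yields
    # (sentence, language_code) chunks so downstream consumers can start on the
    # first sentence before generation has finished.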
    def process(self, prompt):
        console.print("inferring language model...")
        console.print(prompt)
        logger.debug("inferring language model...")
        language_code = None
        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            prompt = (
                "Please reply to my message in "
                f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
            )

        self.chat.append({"role": self.user_role, "content": prompt})
        # Generation runs in a background thread; this thread consumes the streamer.
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()
        if self.device == "mps":
            # Sentence splitting is skipped on MPS: accumulate the whole reply
            # and emit it as a single chunk via the final yield below.
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0], language_code)
                    # Keep everything after the first sentence so that no text
                    # is dropped when a chunk crosses a sentence boundary.
                    printable_text = " ".join(sentences[1:])

        self.chat.append({"role": "assistant", "content": generated_text})

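        # Flush any remaining text after the last detected sentence boundary.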
        yield (printable_text, language_code)
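

# Illustrative usage sketch (assumption-laden, not part of the pipeline).
# It bypasses BaseHandler.__init__ with __new__ because the base-class
# constructor signature is defined elsewhere in the repo; treat this as a
# smoke-test sketch under those assumptions, not a supported entry point.
if __name__ == "__main__":
    handler = LanguageModelHandler.__new__(LanguageModelHandler)  # skip BaseHandler.__init__
    handler.setup(device="cuda", gen_kwargs={"max_new_tokens": 64})
    # A tuple input mimics an upstream Whisper result: (transcript, language code).
    for sentence, _lang in handler.process(("What is the capital of France?", "en")):
        console.print(sentence)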