from threading import Thread
from typing import Iterator

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = 'codellama/CodeLlama-13b-Instruct-hf'

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    # Use the standard linear-layer path rather than the tensor-parallel one.
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_4bit=True,  # quantize on load; requires the bitsandbytes package
        device_map='auto',
        use_safetensors=False,
    )
else:
    model = None
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
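    """Assemble the Llama-2-style instruction prompt CodeLlama-Instruct expects:
    the system prompt inside <<SYS>> tags, each past turn wrapped in
    <s>[INST] ... [/INST] markers, then the new user message."""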
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is not stripped; every later turn is.
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
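    """Return how many tokens the fully assembled prompt occupies."""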
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
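    """Stream the model's reply, yielding the accumulated text after each chunk.

    generate() runs in a worker thread; TextIteratorStreamer hands decoded
    text back to this generator as it is produced.
    """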
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    # Run generation in a background thread so this generator can consume
    # the streamer concurrently.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)


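if __name__ == '__main__':
    # Minimal usage sketch, not part of the original app. It assumes a CUDA
    # GPU is available (run() moves inputs to 'cuda' and model is None
    # otherwise); the system prompt and question are illustrative placeholders.
    system_prompt = 'You are a helpful coding assistant.'
    history: list[tuple[str, str]] = []
    question = 'Write a Python function that reverses a string.'
    print('prompt tokens:', get_input_token_length(question, history, system_prompt))
    response = ''
    for response in run(question, history, system_prompt, max_new_tokens=256):
        pass  # each yield is the full response generated so far
    print(response)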