import gradio as gr
from typing import Iterator
import torch
import transformers

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    device_map="cpu",
)


def evaluate(question):
    # Wrap the question in the conversation template and tokenize the full
    # templated prompt.
    prompt = f"The conversation between human and AI assistant.\n[|Human|] {question}\n[|AI|] "
    inputs = tokenizer(prompt, return_tensors="pt")
    # Keep inputs on the same device as the model (loaded on CPU above).
    input_ids = inputs["input_ids"].to(model.device)
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(
            temperature=1.0,
            top_p=0.95,
            num_beams=4,
        ),
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=512,
    )
    # Keep only the assistant's reply, i.e. everything after the [|AI|] tag.
    output = tokenizer.decode(generation_output.sequences[0]).split("[|AI|]")[1]
    return output
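# Quick sanity check (assumes the checkpoint above loaded; beam search on a
# CPU-resident 7B model is slow):
#   print(evaluate("What is the capital of France?"))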


def generate_prompt_with_history(text: str, history: list, tokenizer, max_length=2048):
    # history is a list of (human, ai) pairs; render each turn in the template.
    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    flag = False  # becomes True once at least one turn fits the token budget
    # Keep turns from newest to oldest while they fit, and stop at the first
    # turn that does not fit so the retained context stays contiguous.
    for x in history[::-1]:
        if tokenizer(history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return history_text, tokenizer(history_text, return_tensors="pt")
    else:
        return False
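# Example round-trip (the history pair is hypothetical); returns False instead
# of a (prompt, inputs) tuple when even the newest turn exceeds max_length:
#   prompt, inputs = generate_prompt_with_history(
#       "And in Italy?", [("What is the capital of France?", "Paris.")], tokenizer)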


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    # True if s ends with a stop word or with any proper prefix of one, so a
    # streaming caller can hold back text that might grow into a stop word.
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
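# For example:
#   is_stop_word_or_prefix("some text [|Hu", ["[|Human|]", "[|AI|]"])  # True
#   is_stop_word_or_prefix("plain text", ["[|Human|]", "[|AI|]"])      # False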


def greedy_search(input_ids: torch.Tensor,
                  model: torch.nn.Module,
                  tokenizer: transformers.PreTrainedTokenizer,
                  stop_words: list,
                  max_length: int,
                  temperature: float = 1.0,
                  top_p: float = 1.0,
                  top_k: int = 25) -> Iterator[str]:
    # Despite its name, this streams tokens using temperature plus nucleus
    # (top-p) sampling; top_k is accepted for API symmetry but never used.
    generated_tokens = []
    past_key_values = None
    for i in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # With a KV cache, only the newest token needs a forward pass.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            logits /= temperature

            probs = torch.softmax(logits, dim=-1)

            # Nucleus filtering: zero out everything beyond the top-p mass.
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0

            # Renormalize, sample, and map back to vocabulary indices.
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)
            next_token = torch.gather(probs_idx, -1, next_token)

            input_ids = torch.cat((input_ids, next_token), dim=-1)

            generated_tokens.append(next_token[0].item())
            text = tokenizer.decode(generated_tokens)

            # Yield the full decoded continuation so far; stop once any stop
            # word has fully appeared.
            yield text
            if any(x in text for x in stop_words):
                return
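# greedy_search is a generator: each value is the decoded continuation so far.
# Sketch (assumes `input_ids` comes from a tokenized prompt as above):
#   for partial in greedy_search(input_ids, model, tokenizer,
#                                stop_words=["[|Human|]"], max_length=32):
#       print(partial)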


@torch.no_grad()
def predict(text: str,
            history: list = None,
            top_p: float = 0.95,
            temperature: float = 1.0,
            max_length_tokens: int = 512,
            max_context_length_tokens: int = 2048):
    if text == "":
        return ""
    history = history or []

    result = generate_prompt_with_history(text, history, tokenizer,
                                          max_length=max_context_length_tokens)
    if result is False:
        # Even the newest turn does not fit in the context budget.
        return ""
    prompt, inputs = result

    input_ids = inputs["input_ids"].to(model.device)
    output = []

    for x in greedy_search(input_ids, model, tokenizer,
                           stop_words=["[|Human|]", "[|AI|]"],
                           max_length=max_length_tokens,
                           temperature=temperature,
                           top_p=top_p):
        if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
            # Trim everything from the next speaker tag onward.
            if "[|Human|]" in x:
                x = x[: x.index("[|Human|]")].strip()
            elif "[| Human |]" in x:
                x = x[: x.index("[| Human |]")].strip()
            if "[|AI|]" in x:
                x = x[: x.index("[|AI|]")].strip()
            output.append(x.strip(" "))
    return output[-1] if output else ""
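# Direct call without the Gradio UI (the history pair is hypothetical):
#   print(predict("And its population?",
#                 history=[("What is the capital of France?", "Paris.")]))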


iface = gr.Interface(fn=predict,
                     inputs="text",
                     outputs="text",
                     title="Learn with ChadGPT",
                     description="Ciao!!!")

iface.launch(inline=False)