| """This script refers to the dialogue example of streamlit, the interactive |
| generation code of chatglm2 and transformers. |
| |
| We mainly modified part of the code logic to adapt to the |
| generation of our model. |
| Please refer to these links below for more information: |
| 1. streamlit chat example: |
| https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps |
| 2. chatglm2: |
| https://github.com/THUDM/ChatGLM2-6B |
| 3. transformers: |
| https://github.com/huggingface/transformers |
| Please run with the command `streamlit run path/to/web_demo.py |
| --server.address=0.0.0.0 --server.port 7860`. |
| Using `python path/to/web_demo.py` may cause unknown problems. |
| """ |

import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
class GenerationConfig:
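    """Generation settings forwarded (via `asdict`) to `generate_interactive`
    and merged into the model's generation config."""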

    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
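    """Stream chat completions from `model`.

    Yields the decoded response after every generated token so the caller
    can render a "typing" effect; generation stops at an EOS token or a
    stopping criterion.
    """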
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    # keep the inputs on the same device as the model so the demo works on
    # both CPU and GPU
    input_ids = inputs['input_ids'].to(model.device)
    input_ids_seq_length = input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default "
            f'({repr(generation_config.max_length)}) to control the '
            'generation length. This behaviour is deprecated and will be '
            'removed from the config in v5 of Transformers -- we recommend '
            'using `max_new_tokens` to control the maximum length of the '
            'generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warning(
                f"Both 'max_new_tokens' "
                f'(={generation_config.max_new_tokens}) and '
                f"'max_length' (={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)')

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider '
            "increasing 'max_new_tokens'.")

    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()
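
    # Merge user-supplied processors/criteria with the ones implied by the
    # generation config. Note: the _get_* helpers are private Transformers
    # APIs and may change between versions.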
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
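    # Auto-regressive decoding loop: sample one token per step, append it to
    # input_ids, and yield the decoded partial response.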
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        # only the logits at the last position matter for the next token
        next_token_logits = outputs.logits[:, -1, :]
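
        # shape the distribution: processors apply penalties (e.g. the
        # repetition penalty), warpers apply temperature and top-p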
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
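
        # sample from the softmax distribution, or take the argmax greedily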
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # append the new token and update the cache bookkeeping
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        # a sequence stays unfinished only while its new token is not an EOS
        not_eos = torch.ones_like(next_tokens, dtype=torch.bool)
        for eos_id in eos_token_id:
            not_eos = not_eos & (next_tokens != eos_id)
        unfinished_sequences = unfinished_sequences.mul(not_eos.long())

        # decode everything generated so far, dropping the prompt and any
        # trailing EOS token
        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids and output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
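
        # stop once every sequence has finished or a stopping criterion fires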
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model_dir = 'internlm/internlm2-chat-1_8b'
    # trust_remote_code is required for InternLM's custom modeling code;
    # bfloat16 halves the memory footprint
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, trust_remote_code=True).to(torch.bfloat16)
    if torch.cuda.is_available():
        # use the GPU when available; the demo also runs (slowly) on CPU
        model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_dir,
                                              trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)

    return generation_config


# InternLM2 chat template: every turn is wrapped in
# <|im_start|>{role}\n...<|im_end|>\n markers
user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = ('<|im_start|>user\n{user}<|im_end|>\n'
                    '<|im_start|>assistant\n')


def combine_history(prompt):
    """Flatten the chat history plus the new query into one chat-format
    prompt string."""
    messages = st.session_state.messages
    meta_instruction = ('You are an experienced scenic-area tour guide, and '
                        'you will answer my questions about the scenic '
                        'area.')
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError(f"Unknown role: {message['role']}")
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    print('Loading model...')
    model, tokenizer = load_model()
    print('Model loaded.')

    st.title('Chatbot')

    generation_config = prepare_generation_config()

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # replay the conversation so far on each Streamlit rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    if prompt := st.chat_input('Please enter your question'):
        with st.chat_message('user'):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
        })
        with st.chat_message('robot'):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    # 92542 is the id of <|im_end|> in the InternLM2 tokenizer
                    additional_eos_token_id=92542,
                    **asdict(generation_config),
            ):
                # stream the partial response with a typing cursor
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,
        })


if __name__ == '__main__':
    main()