import streamlit as st
import torch
from streamlit_chat import message
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
# Browser-tab title and icon. The title names the checkpoint that get_model()
# loads by default (Baichuan-13B-Chat), not the ChatGLM demo this script was
# derived from; "演示" means "demo".
st.set_page_config(
    page_title="Baichuan-13B-Chat 演示",
    page_icon=":robot:",
)
|
|
|
|
@st.cache_resource
def get_model(model_name="baichuan-inc/Baichuan-13B-Chat"):
    """Load the chat tokenizer and model, cached across Streamlit reruns.

    ``st.cache_resource`` memoizes the return value, so later script reruns
    reuse the loaded weights instead of loading them again.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.
            Defaults to Baichuan-13B-Chat, the model this demo targets.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been put in eval mode.
    """
    # The checkpoint relies on custom code from its repository, hence
    # trust_remote_code=True; use_fast=False requests the slow tokenizer class.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )
    # device_map="auto" lets transformers place the weights on the available
    # devices; float16 halves memory compared with float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    # Use the generation defaults shipped with the checkpoint itself.
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    model = model.eval()
    return tokenizer, model
|
|
|
|
# Maximum number of conversation turns kept; a turn is one
# (user query, model response) pair in the history list.
MAX_TURNS = 20
# Each turn is drawn as two chat bubbles (user + assistant).
MAX_BOXES = 2 * MAX_TURNS
|
|
|
|
| def predict(input, max_length, top_p, temperature, history=None): |
| tokenizer, model = get_model() |
| if history is None: |
| history = [] |
|
|
| with container: |
| if len(history) > 0: |
| if len(history)>MAX_BOXES: |
| history = history[-MAX_TURNS:] |
| for i, (query, response) in enumerate(history): |
| message(query, avatar_style="big-smile", key=str(i) + "_user") |
| message(response, avatar_style="bottts", key=str(i)) |
|
|
| message(input, avatar_style="big-smile", key=str(len(history)) + "_user") |
| st.write("AIζ£ε¨εε€:") |
| with st.empty(): |
| for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, |
| temperature=temperature): |
| query, response = history[-1] |
| st.write(response) |
|
|
| return history |
|
|
|
|
# Transcript area; predict() draws the chat bubbles into it via `with container:`.
container = st.container()

# Multi-line box for the user's next message.
# Label "用户命令输入" = "user command input";
# placeholder "请在这儿输入您的命令" = "please enter your command here".
prompt_text = st.text_area(
    label="用户命令输入",
    height=100,
    placeholder="请在这儿输入您的命令",
)
|
|
# Generation controls shown in the sidebar; their values are passed to predict().
max_length = st.sidebar.slider(
    "max_length", min_value=0, max_value=4096, value=2048, step=1
)
top_p = st.sidebar.slider(
    "top_p", min_value=0.0, max_value=1.0, value=0.6, step=0.01
)
# NOTE(review): the slider allows temperature=0.0; sampling-based generation
# may reject a non-positive temperature. Consider a small positive minimum.
temperature = st.sidebar.slider(
    "temperature", min_value=0.0, max_value=1.0, value=0.95, step=0.01
)
|
|
# Chat history (list of turns) kept in session state so it survives reruns;
# created empty the first time this session runs the script.
st.session_state.setdefault("state", [])
|
|
| if st.button("ει", key="predict"): |
| with st.spinner("AIζ£ε¨ζθοΌθ―·η¨η........"): |
| |
| st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"]) |
|
|