import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
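
# Basic page setup: browser-tab title, icon, and a wide layout so the
# chat column has room to breathe.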
st.set_page_config(
    page_title="Apertus-8B Chat",
    page_icon="🤖",
    layout="wide",
)

st.title("🤖 Chat with Apertus-8B-Instruct")
st.caption("A Streamlit app running swiss-ai/Apertus-8B-Instruct-2509")
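

# st.cache_resource keeps the loaded model in memory across Streamlit's
# script reruns, so the weights are downloaded and quantized only once
# per process instead of on every user interaction.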
@st.cache_resource
def load_model():
    """Loads the model and tokenizer with 4-bit quantization."""
    model_id = "swiss-ai/Apertus-8B-Instruct-2509"
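
    # NF4 ("normal float 4") quantization stores the weights in 4 bits,
    # cutting the 8B model's memory footprint to roughly a quarter of
    # fp16, while computation is carried out in bfloat16.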
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
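
    # device_map="auto" lets Accelerate pick the device placement for the
    # model's layers automatically (GPU if available).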
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    return tokenizer, model


with st.spinner("Loading Apertus-8B model... This might take a moment."):
    tokenizer, model = load_model()
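
# Conversation history lives in st.session_state so it survives the
# rerun Streamlit performs on every interaction.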
if "messages" not in st.session_state:
    st.session_state.messages = []
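
# Replay the conversation so far on each rerun.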
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What would you like to ask?"):
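    # Record and display the user's turn before generating a reply.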
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
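            # Format the whole conversation with the model's chat template
            # rather than tokenizing the raw prompt string: instruct models
            # expect their role markers, and this also gives the model the
            # previous turns as context.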
            input_ids = tokenizer.apply_chat_template(
                st.session_state.messages,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)

            outputs = model.generate(
                input_ids,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )
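
            # Decode only the newly generated tokens; slicing off the prompt
            # tokens is more robust than stripping the prompt text with a
            # string replacement.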
            response = tokenizer.decode(
                outputs[0][input_ids.shape[-1]:],
                skip_special_tokens=True,
            ).strip()

            st.markdown(response)
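
            # Persist the assistant's turn so the next rerun redraws it.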
            st.session_state.messages.append({"role": "assistant", "content": response})