import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

LOCAL_DIR = "c:/Users/Public/CogniXpert-Model-v1.0"
BASE_ID = "unsloth/meta-llama-3.1-8b-bnb-4bit"
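

# Cache the tokenizer and model per use_adapter value so Streamlit reruns
# don't reload the 8B base weights on every interaction.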
@st.cache_resource
def load_model(use_adapter: bool):
    tok = AutoTokenizer.from_pretrained(LOCAL_DIR)
    base = AutoModelForCausalLM.from_pretrained(BASE_ID, device_map="auto")
    # Attach the LoRA adapter only when both its config and weights exist on disk.
    cfg_path = os.path.join(LOCAL_DIR, "adapter_config.json")
    safetensors_path = os.path.join(LOCAL_DIR, "adapter_model.safetensors")
    bin_path = os.path.join(LOCAL_DIR, "adapter_model.bin")
    has_config = os.path.exists(cfg_path)
    has_weights = os.path.exists(safetensors_path) or os.path.exists(bin_path)
    if use_adapter and has_config and has_weights:
        base = PeftModel.from_pretrained(base, LOCAL_DIR)
    elif use_adapter and has_config and not has_weights:
        st.warning("LoRA adapter config found but weights missing. Proceeding without adapter.")
    return tok, base
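

# Build a Llama 3.1 chat-format prompt by hand: system turn first, then alternating
# turns (messages[] alternates user, assistant, user, ...), ending with an open
# assistant header for the model to complete.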
def format_prompt(system_text: str, messages: list[str]):
    # The official Llama 3 template puts a blank line after each <|end_header_id|>.
    content = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + system_text + "<|eot_id|>"
    for i, msg in enumerate(messages):
        role = "user" if i % 2 == 0 else "assistant"
        content += "<|start_header_id|>" + role + "<|end_header_id|>\n\n" + msg + "<|eot_id|>"
    content += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return content
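

# Streamlit page setup and persistent chat history.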
st.set_page_config(page_title="CogniXpert Chat", page_icon="🧠", layout="centered")

if "messages" not in st.session_state:
    st.session_state.messages = []

st.title("CogniXpert Chat")
st.caption("Supportive, safety-aware conversational AI. Not medical advice.")

# Sidebar controls for decoding behavior.
use_adapter = st.sidebar.checkbox("Use LoRA adapter if available", value=True)
temperature = st.sidebar.slider("Temperature", 0.0, 1.5, 0.6, 0.05)
top_p = st.sidebar.slider("Top-p", 0.1, 1.0, 0.9, 0.05)
max_new_tokens = st.sidebar.slider("Max new tokens", 32, 1024, 256, 32)

system_default = "You are CogniXpert, a supportive, safety-aware assistant. Encourage help-seeking and evidence-based coping strategies. Avoid clinical diagnosis or prescriptive treatment."
system_text = st.text_area("System prompt", value=system_default, height=100)
tok, model = load_model(use_adapter)
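
# Replay stored history: even indices are user turns, odd indices are assistant turns.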
for i, msg in enumerate(st.session_state.messages):
    role = "assistant" if i % 2 == 1 else "user"
    with st.chat_message(role):
        st.markdown(msg)
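
# Handle a new user turn: echo it, rebuild the full prompt, generate a completion,
# and slice the reply out between the final assistant header and <|eot_id|>.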
user_input = st.chat_input("Type your message")
if user_input:
    st.session_state.messages.append(user_input)
    with st.chat_message("user"):
        st.markdown(user_input)
    prompt = format_prompt(system_text, st.session_state.messages)
    # The prompt already starts with <|begin_of_text|>, so skip the tokenizer's own BOS.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,  # sampling with temperature=0 raises; fall back to greedy
            pad_token_id=tok.eos_token_id,
        )
    # Keep special tokens so the final assistant header can anchor the reply slice.
    text = tok.decode(out[0], skip_special_tokens=False)
    key = "<|start_header_id|>assistant<|end_header_id|>"
    idx = text.rfind(key)
    resp = text[idx + len(key):]
    eot = resp.find("<|eot_id|>")
    if eot != -1:
        resp = resp[:eot]
    resp = resp.strip()
    st.session_state.messages.append(resp)
    with st.chat_message("assistant"):
        st.markdown(resp)