import os

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local directory holding the fine-tuned tokenizer and (optionally) LoRA adapter files.
LOCAL_DIR = "c:/Users/Public/CogniXpert-Model-v1.0"
# 4-bit quantized Llama 3.1 8B base model.
BASE_ID = "unsloth/meta-llama-3.1-8b-bnb-4bit"


@st.cache_resource
def load_model(use_adapter: bool):
    """Load tokenizer and base model once per session; attach the LoRA adapter if present."""
    tok = AutoTokenizer.from_pretrained(LOCAL_DIR)
    base = AutoModelForCausalLM.from_pretrained(BASE_ID, device_map="auto")

    cfg_path = os.path.join(LOCAL_DIR, "adapter_config.json")
    safetensors_path = os.path.join(LOCAL_DIR, "adapter_model.safetensors")
    bin_path = os.path.join(LOCAL_DIR, "adapter_model.bin")
    has_config = os.path.exists(cfg_path)
    has_weights = os.path.exists(safetensors_path) or os.path.exists(bin_path)

    if use_adapter and has_config and has_weights:
        base = PeftModel.from_pretrained(base, LOCAL_DIR)
    elif use_adapter and has_config and not has_weights:
        st.warning("LoRA adapter config found but weights missing. Proceeding without adapter.")
    return tok, base


def format_prompt(system_text: str, messages: list[str]) -> str:
    """Build a Llama 3.1 chat prompt by hand.

    Messages alternate user/assistant, starting with the user. Note the double
    newline after each <|end_header_id|>, which the Llama 3 template expects.
    """
    content = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        + system_text
        + "<|eot_id|>"
    )
    for i, msg in enumerate(messages):
        role = "user" if i % 2 == 0 else "assistant"
        content += f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg}<|eot_id|>"
    # Generation prompt: the model continues as the assistant.
    content += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return content


st.set_page_config(page_title="CogniXpert Chat", page_icon="🧠", layout="centered")

if "messages" not in st.session_state:
    st.session_state.messages = []

st.title("CogniXpert Chat")
st.caption("Supportive, safety-aware conversational AI. Not medical advice.")

# Sidebar controls for the adapter and sampling parameters.
use_adapter = st.sidebar.checkbox("Use LoRA adapter if available", value=True)
# Minimum is 0.05 rather than 0.0: do_sample requires a strictly positive temperature.
temperature = st.sidebar.slider("Temperature", 0.05, 1.5, 0.6, 0.05)
top_p = st.sidebar.slider("Top-p", 0.1, 1.0, 0.9, 0.05)
max_new_tokens = st.sidebar.slider("Max new tokens", 32, 1024, 256, 32)

system_default = (
    "You are CogniXpert, a supportive, safety-aware assistant. Encourage "
    "help-seeking and evidence-based coping strategies. Avoid clinical "
    "diagnosis or prescriptive treatment."
)
system_text = st.text_area("System prompt", value=system_default, height=100)

tok, model = load_model(use_adapter)

# Replay the conversation so far (even indices are user turns, odd are assistant).
for i, msg in enumerate(st.session_state.messages):
    role = "assistant" if i % 2 == 1 else "user"
    with st.chat_message(role):
        st.markdown(msg)

user_input = st.chat_input("Type your message")
if user_input:
    st.session_state.messages.append(user_input)
    with st.chat_message("user"):
        st.markdown(user_input)

    prompt = format_prompt(system_text, st.session_state.messages)
    # The prompt already starts with <|begin_of_text|>, so skip the tokenizer's
    # automatic BOS token to avoid doubling it.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tok.eos_token_id,  # silences the missing-pad-token warning
        )

    # Decode only the newly generated tokens, then trim at the end-of-turn marker.
    gen_tokens = out[0][inputs["input_ids"].shape[1]:]
    resp = tok.decode(gen_tokens, skip_special_tokens=False)
    eot = resp.find("<|eot_id|>")
    if eot != -1:
        resp = resp[:eot]
    resp = resp.strip()

    st.session_state.messages.append(resp)
    with st.chat_message("assistant"):
        st.markdown(resp)
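
# --- Optional alternative to format_prompt (a sketch, not wired into the UI
# above). If the tokenizer saved in LOCAL_DIR ships a chat_template (Llama 3.1
# tokenizers normally do), tokenizer.apply_chat_template builds the same prompt
# without hand-rolling the special tokens. The helper below is illustrative and
# unused; it assumes the same alternating user/assistant message list.
def format_prompt_with_template(tok, system_text: str, messages: list[str]) -> str:
    chat = [{"role": "system", "content": system_text}]
    for i, msg in enumerate(messages):
        chat.append({"role": "user" if i % 2 == 0 else "assistant", "content": msg})
    # add_generation_prompt appends the assistant header so the model replies next.
    return tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)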
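
# --- Optional streaming sketch (not wired into the UI above; assumes
# Streamlit >= 1.31 for st.write_stream). Instead of blocking on
# model.generate and rendering the full reply at once, a TextIteratorStreamer
# runs generation on a background thread and yields text chunks as they are
# produced. The function name and signature are illustrative.
def generate_streaming(tok, model, prompt: str, **gen_kwargs) -> str:
    import threading

    from transformers import TextIteratorStreamer

    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # skip_prompt drops the input tokens; skip_special_tokens drops <|eot_id|> etc.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, **gen_kwargs),
    )
    thread.start()
    # st.write_stream consumes the iterator, rendering chunks live, and returns
    # the concatenated text so it can be stored in session state afterwards.
    return st.write_stream(streamer)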