File size: 3,492 Bytes
173a672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Local directory handed to the tokenizer loader and to PeftModel; it is also
# where load_model looks for adapter_config.json and the adapter weight files.
LOCAL_DIR = "c:/Users/Public/CogniXpert-Model-v1.0"
# Model identifier passed to AutoModelForCausalLM.from_pretrained as the base model.
BASE_ID = "unsloth/meta-llama-3.1-8b-bnb-4bit"

@st.cache_resource
def load_model(use_adapter: bool):
    """Load the tokenizer and base model, optionally wrapped with the local LoRA adapter.

    The tokenizer is read from ``LOCAL_DIR`` and the base model from
    ``BASE_ID``. The adapter is applied only when ``use_adapter`` is true and
    ``LOCAL_DIR`` contains both ``adapter_config.json`` and one of
    ``adapter_model.safetensors`` / ``adapter_model.bin``.

    Args:
        use_adapter: Whether to try to apply the LoRA adapter from ``LOCAL_DIR``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is the plain base model when
        the adapter was not requested or could not be applied.
    """
    tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR)
    model = AutoModelForCausalLM.from_pretrained(BASE_ID, device_map="auto")

    config_present = os.path.exists(os.path.join(LOCAL_DIR, "adapter_config.json"))
    # Either serialization format of the adapter weights is acceptable.
    weights_present = any(
        os.path.exists(os.path.join(LOCAL_DIR, filename))
        for filename in ("adapter_model.safetensors", "adapter_model.bin")
    )

    # A missing config is not reported; only a config without weights warns.
    if use_adapter and config_present:
        if weights_present:
            model = PeftModel.from_pretrained(model, LOCAL_DIR)
        else:
            st.warning("LoRA adapter config found but weights missing. Proceeding without adapter.")

    return tokenizer, model

def format_prompt(system_text: str, messages: list[str]) -> str:
    """Render a system prompt and an alternating chat transcript as one prompt string.

    Messages are assumed to alternate strictly, starting with the user:
    even indices are user turns and odd indices are assistant turns. Each
    turn is wrapped in ``<|start_header_id|>ROLE<|end_header_id|>`` and
    ``<|eot_id|>`` markers, and the result ends with an open assistant header
    so the model continues with the assistant's reply.

    Args:
        system_text: Text placed in the leading system turn.
        messages: Transcript so far, user turn first.

    Returns:
        The full prompt string, beginning with ``<|begin_of_text|>``.
    """
    # Collect pieces and join once instead of growing a string with "+=".
    pieces = [
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
        system_text,
        "<|eot_id|>\n",
    ]
    for position, message in enumerate(messages):
        role = "user" if position % 2 == 0 else "assistant"
        pieces.extend(
            (
                "<|start_header_id|>" + role + "<|end_header_id|>\n",
                message,
                "<|eot_id|>\n",
            )
        )
    # Open assistant header: generation picks up from here.
    pieces.append("<|start_header_id|>assistant<|end_header_id|>\n")
    return "".join(pieces)

# Page-level configuration for the chat app.
st.set_page_config(
    page_title="CogniXpert Chat",
    page_icon="🧠",
    layout="centered",
)

# The transcript is a flat list of strings kept in session state; even
# positions are user turns and odd positions are assistant turns.
if "messages" not in st.session_state:
    st.session_state.messages = []

st.title("CogniXpert Chat")
st.caption("Supportive, safety‑aware conversational AI. Not medical advice.")

# Sidebar controls: adapter toggle and generation settings.
use_adapter = st.sidebar.checkbox("Use LoRA adapter if available", value=True)
temperature = st.sidebar.slider(
    "Temperature", min_value=0.0, max_value=1.5, value=0.6, step=0.05
)
top_p = st.sidebar.slider(
    "Top‑p", min_value=0.1, max_value=1.0, value=0.9, step=0.05
)
max_new_tokens = st.sidebar.slider(
    "Max new tokens", min_value=32, max_value=1024, value=256, step=32
)

# Editable system prompt, pre-filled with the default persona text.
system_default = (
    "You are CogniXpert, a supportive, safety‑aware assistant. "
    "Encourage help‑seeking and evidence‑based coping strategies. "
    "Avoid clinical diagnosis or prescriptive treatment."
)
system_text = st.text_area("System prompt", value=system_default, height=100)

tok, model = load_model(use_adapter)

# Replay the stored transcript so earlier turns are shown again on each run.
for position, past_message in enumerate(st.session_state.messages):
    speaker = "user" if position % 2 == 0 else "assistant"
    with st.chat_message(speaker):
        st.markdown(past_message)

user_input = st.chat_input("Type your message")

# Handle a newly submitted message: store and echo it, generate the assistant
# reply from the whole transcript, then store and render that reply.
if user_input:
    st.session_state.messages.append(user_input)
    with st.chat_message("user"):
        st.markdown(user_input)

    prompt = format_prompt(system_text, st.session_state.messages)
    # The prompt text already starts with <|begin_of_text|>, so the tokenizer
    # must not prepend its own BOS token as well.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # The temperature slider can be set to 0.0, but sampling with a zero
    # temperature is rejected by generate(); use greedy decoding in that case.
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens. Searching the full decoded text
    # for the last assistant header picks the wrong span when the model goes
    # on to write further turns of its own.
    prompt_len = inputs["input_ids"].shape[1]
    resp = tok.decode(out[0][prompt_len:], skip_special_tokens=False)

    # Keep the text before the first end-of-turn, end-of-text or new-header
    # marker so no special-token literals leak into the displayed reply.
    for stop_marker in ("<|eot_id|>", "<|end_of_text|>", "<|start_header_id|>"):
        resp = resp.partition(stop_marker)[0]
    resp = resp.strip()

    st.session_state.messages.append(resp)
    with st.chat_message("assistant"):
        st.markdown(resp)