Spaces:
Build error
Build error
Upload 2 files
Browse files- main_v1.py +467 -0
- rag.py +341 -0
main_v1.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import rag # Import the rag module
|
| 4 |
+
import os # Import os for file path handling
|
| 5 |
+
|
| 6 |
+
def main():
    """Render the AAC (Augmentative and Alternative Communication) board app.

    Builds a themed on-screen keyboard, a compose box, control buttons
    (CLEAR / SPEAK / SEND / DEL / WORD), and a sidebar conversation pane
    backed by ``rag.AACAssistant``.

    Fix vs. the previous revision: the compose box used
    ``value=st.session_state.text_output`` with a *separate* widget key
    ("main_text_output"), so the widget kept its own state and button
    callbacks that mutated ``text_output`` never showed up on screen
    (and clearing it after the widget rendered raises a
    StreamlitAPIException on current Streamlit). The text area is now
    bound directly to ``key="text_output"`` and every mutating control
    button runs through an ``on_click`` callback, which Streamlit
    executes *before* the next render — so keyboard presses, CLEAR,
    DEL, WORD, and SEND all take effect immediately.
    """
    st.set_page_config(layout="wide", page_title="Communication Board")

    # --- Session State Initialization ---
    # 'assistant' is the sentinel key: if it is missing, this is the first
    # run of the session and every default must be seeded together.
    if 'assistant' not in st.session_state:
        st.session_state.current_page = "main"
        st.session_state.show_custom_words = False
        st.session_state.custom_words = []
        st.session_state.text_size = 22
        st.session_state.theme = "Default"
        st.session_state.speech_rate = 1.0
        st.session_state.voice_option = "Default Voice"
        st.session_state.messages = []
        st.session_state.assistant = None
        st.session_state.text_output = ""

    # --- Theme Colors ---
    # Per-theme palette: page background/text plus one color per word category.
    theme_colors = {
        "Default": {
            "bg": "#FFFFFF", "text": "#000000",
            "pronoun": "#FFFF99", "verb": "#CCFFCC",
            "question": "#CCCCFF", "common": "#FFCC99",
            "preposition": "#99CCFF", "descriptive": "#CCFF99",
            "misc": "#FFB6C1"
        },
        "High Contrast": {
            "bg": "#FFFFFF", "text": "#000000",
            "pronoun": "#FFFF00", "verb": "#00FF00",
            "question": "#0000FF", "common": "#FF6600",
            "preposition": "#00CCFF", "descriptive": "#66FF33",
            "misc": "#FF3366"
        },
        "Pastel": {
            "bg": "#F8F8FF", "text": "#333333",
            "pronoun": "#FFEFD5", "verb": "#E0FFFF",
            "question": "#D8BFD8", "common": "#FFE4B5",
            "preposition": "#B0E0E6", "descriptive": "#F0FFF0",
            "misc": "#FFF0F5"
        },
        "Dark Mode": {
            "bg": "#333333", "text": "#FFFFFF",
            "pronoun": "#8B8B00", "verb": "#006400",
            "question": "#00008B", "common": "#8B4500",
            "preposition": "#00688B", "descriptive": "#698B22",
            "misc": "#8B1A1A"
        }
    }
    colors = theme_colors[st.session_state.theme]

    # --- Helper Function to Initialize Assistant ---
    @st.cache_resource  # Cache the assistant across reruns of the script
    def initialize_assistant(doc_path):
        """Initializes the AACAssistant.

        Creates a small placeholder experiences document if none exists,
        then constructs ``rag.AACAssistant``. Returns None on failure so
        the UI can keep running without the assistant.
        """
        if not os.path.exists(doc_path):
            st.sidebar.warning(f"Doc '{os.path.basename(doc_path)}' not found. Creating dummy.")
            try:
                with open(doc_path, "w") as f:
                    f.write("""
I grew up in Seattle and love the rain.
My favorite hobby is playing chess.
I have a dog named Max.
I studied computer science.
I enjoy sci-fi movies.
""")
            except Exception as e:
                st.sidebar.error(f"Failed to create dummy doc: {e}")
                return None
        try:
            assistant = rag.AACAssistant(doc_path)
            st.sidebar.success("AAC Assistant Initialized.")
            return assistant
        except Exception as e:
            st.sidebar.error(f"Error initializing AAC Assistant: {e}")
            st.sidebar.error("Ensure Ollama/LM Studio running.")
            return None

    DEFAULT_DOCUMENT_PATH = "aac_user_experiences.txt"

    # --- CSS Styling ---
    css = f"""
    <style>
    .big-font {{
        font-size:{st.session_state.text_size}px !important;
        text-align: center;
    }}
    .output-box {{
        border: 2px solid #ddd;
        border-radius: 5px;
        padding: 15px;
        min-height: 100px;
        background-color: white;
        margin-bottom: 15px;
        font-size: {st.session_state.text_size}px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }}
    div[data-testid="stHorizontalBlock"] {{
        gap: 5px;
    }}
    section[data-testid="stSidebar"] {{
        width: 20rem;
        background-color: {colors["bg"]};
    }}
    body {{
        background-color: {colors["bg"]};
        color: {colors["text"]};
    }}
    .stButton>button {{
        width: 100%;
        height: 60px;
        font-size: {max(16, st.session_state.text_size - 6)}px;
        font-weight: bold;
        white-space: normal;
        padding: 0px;
        transition: transform 0.1s ease;
    }}
    .stButton>button:hover {{
        filter: brightness(95%);
        transform: scale(1.03);
        box-shadow: 0 2px 3px rgba(0,0,0,0.1);
    }}
    .control-button {{
        height: 80px !important;
        font-size: {max(18, st.session_state.text_size - 4)}px !important;
    }}
    .btn-pronoun {{
        background-color: {colors["pronoun"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-verb {{
        background-color: {colors["verb"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-question {{
        background-color: {colors["question"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-common {{
        background-color: {colors["common"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-preposition {{
        background-color: {colors["preposition"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-descriptive {{
        background-color: {colors["descriptive"]} !important;
        border: 1px solid #888 !important;
    }}
    .btn-misc {{
        background-color: {colors["misc"]} !important;
        border: 1px solid #888 !important;
    }}
    /* Sidebar chat styling */
    .sidebar .stChatMessage {{
        background-color: {colors.get('bg', '#FFFFFF')}; /* Use theme background */
        border-radius: 8px;
    }}
    </style>
    """

    # --- JS for Button Coloring ---
    # Button keys encode the category as the 4th underscore-separated token
    # (key_<row>_<col>_<category>_<word>); this script maps each button to
    # its btn-<category> CSS class and re-runs on every DOM mutation so
    # Streamlit reruns stay colored.
    js = """
    <script>
    function colorButtons() {
        const buttons = document.querySelectorAll('button[id^="key_"]');
        buttons.forEach(button => {
            const id = button.id;
            const parts = id.split('_');
            if (parts.length >= 4) {
                const category = parts[3];
                button.classList.add('btn-' + category);
            }
        });
    }
    document.addEventListener('DOMContentLoaded', colorButtons);
    new MutationObserver(colorButtons).observe(document.body, { childList: true, subtree: true });
    </script>
    """

    st.markdown(css, unsafe_allow_html=True)
    st.markdown(js, unsafe_allow_html=True)

    # --- Keyboard Layout ---
    # 8 rows x 10 keys; each entry carries the inserted word and its
    # category (drives button coloring). "display" overrides the label when
    # the word value must stay unique (e.g. letter I vs. pronoun I).
    layout = [
        [
            {"word": "I", "category": "pronoun"},
            {"word": "am", "category": "verb"},
            {"word": "how", "category": "question"},
            {"word": "what", "category": "question"},
            {"word": "when", "category": "question"},
            {"word": "where", "category": "question"},
            {"word": "who", "category": "question"},
            {"word": "why", "category": "question"},
            {"word": "That", "category": "pronoun"},
            {"word": "Please", "category": "common"}
        ],
        [
            {"word": "me", "category": "pronoun"},
            {"word": "are", "category": "verb"},
            {"word": "is", "category": "verb"},
            {"word": "was", "category": "verb"},
            {"word": "will", "category": "verb"},
            {"word": "help", "category": "verb"},
            {"word": "need", "category": "verb"},
            {"word": "want", "category": "verb"},
            {"word": "thank you", "category": "common"},
            {"word": "sorry", "category": "common"}
        ],
        [
            {"word": "my", "category": "pronoun"},
            {"word": "can", "category": "verb"},
            {"word": "A", "category": "misc"},
            {"word": "B", "category": "misc"},
            {"word": "C", "category": "misc"},
            {"word": "D", "category": "misc"},
            {"word": "E", "category": "misc"},
            {"word": "F", "category": "misc"},
            {"word": "G", "category": "misc"},
            {"word": "H", "category": "misc"}
        ],
        [
            {"word": "it", "category": "pronoun"},
            {"word": "did", "category": "verb"},
            {"word": "letter_I", "category": "misc", "display": "I"},
            {"word": "J", "category": "misc"},
            {"word": "K", "category": "misc"},
            {"word": "L", "category": "misc"},
            {"word": "M", "category": "misc"},
            {"word": "N", "category": "misc"},
            {"word": "O", "category": "misc"},
            {"word": "P", "category": "misc"}
        ],
        [
            {"word": "they", "category": "pronoun"},
            {"word": "do", "category": "verb"},
            {"word": "Q", "category": "misc"},
            {"word": "R", "category": "misc"},
            {"word": "S", "category": "misc"},
            {"word": "T", "category": "misc"},
            {"word": "U", "category": "misc"},
            {"word": "V", "category": "misc"},
            {"word": "W", "category": "misc"},
            {"word": "X", "category": "misc"}
        ],
        [
            {"word": "we", "category": "pronoun"},
            {"word": "Y", "category": "misc"},
            {"word": "Z", "category": "misc"},
            {"word": "1", "category": "misc"},
            {"word": "2", "category": "misc"},
            {"word": "3", "category": "misc"},
            {"word": "4", "category": "misc"},
            {"word": "5", "category": "misc"},
            {"word": ".", "category": "misc"},
            {"word": "?", "category": "misc"}
        ],
        [
            {"word": "you", "category": "pronoun"},
            {"word": "6", "category": "misc"},
            {"word": "7", "category": "misc"},
            {"word": "8", "category": "misc"},
            {"word": "9", "category": "misc"},
            {"word": "0", "category": "misc"},
            {"word": "-", "category": "misc"},
            {"word": "!", "category": "misc"},
            {"word": ",", "category": "misc"},
            {"word": "SPACE", "category": "misc"}
        ],
        [
            {"word": "your", "category": "pronoun"},
            {"word": "like", "category": "verb"},
            {"word": "to", "category": "preposition"},
            {"word": "with", "category": "preposition"},
            {"word": "in", "category": "preposition"},
            {"word": "the", "category": "misc"},
            {"word": "and", "category": "misc"},
            {"word": "but", "category": "misc"},
            {"word": "not", "category": "descriptive"},
            {"word": "yes", "category": "common"}
        ]
    ]

    # --- Add Custom Words ---
    # Pack user-defined words into extra 10-key rows; pad the final partial
    # row with blank placeholders so st.columns gets a uniform width.
    if st.session_state.custom_words:
        custom_row = []
        for idx, word_info in enumerate(st.session_state.custom_words):
            word = word_info["word"]
            category = word_info["category"]
            # Prefix keeps the key unique even for duplicate custom words.
            custom_row.append({"word": f"custom_{idx}_{word}", "display": word, "category": category})
            if len(custom_row) == 10:
                layout.append(custom_row)
                custom_row = []
        if custom_row:
            while len(custom_row) < 10:
                custom_row.append({"word": "", "category": "misc"})
            layout.append(custom_row)

    # --- Initialize Assistant ---
    # Retry on every rerun while it is None (e.g. Ollama came up later).
    if st.session_state.assistant is None:
        st.session_state.assistant = initialize_assistant(DEFAULT_DOCUMENT_PATH)

    # --- Output Box (top, before keyboard) ---
    st.title("Communication Board")
    # Bind the widget directly to the 'text_output' session key: keyboard
    # and control-button callbacks mutate st.session_state.text_output and
    # the widget re-renders with the new value on the next run. Do NOT pass
    # value= here — a widget with a session-state key must be driven
    # through that key alone.
    st.text_area("Compose Message:", height=100, key="text_output")

    # --- Keyboard helpers ---
    def add_to_output(word):
        """Append a key press to the composed text.

        Single characters, digits, and punctuation join the current word;
        whole words get a separating space first; SPACE inserts a literal
        space.
        """
        if word == "SPACE":
            st.session_state.text_output += " "
        elif word in [".", "?", "!", ",", "-"]:
            st.session_state.text_output += word
        elif word.isdigit() or (len(word) == 1 and word.isalpha()):
            st.session_state.text_output += word
        else:
            if st.session_state.text_output and not st.session_state.text_output.endswith(" "):
                st.session_state.text_output += " "
            st.session_state.text_output += word

    st.markdown("### Communication Keyboard")
    for row_idx, row in enumerate(layout):
        cols = st.columns(len(row))
        for col_idx, word_info in enumerate(row):
            # Skip the blank padding cells appended for custom-word rows.
            if "word" not in word_info or word_info["word"] == "":
                continue
            word = word_info["word"]
            category = word_info["category"]
            display = word_info.get("display", word)
            # Key encodes the category for the JS colorizer above.
            key = f"key_{row_idx}_{col_idx}_{category}_{word}"

            def make_callback(w=word, d=display):
                # Default args bind per-iteration values (avoids the
                # late-binding closure pitfall inside this loop).
                def cb():
                    # Prefixed keys (custom_*/letter_*) insert their display
                    # text; ordinary keys insert the word itself.
                    if w.startswith("custom_") or w.startswith("letter_"):
                        add_to_output(d)
                    else:
                        add_to_output(w)
                return cb

            with cols[col_idx]:
                st.button(
                    display if word != "SPACE" else "␣",
                    key=key,
                    on_click=make_callback()
                )

    # --- Control Button Callbacks ---
    # All text mutations run as on_click callbacks: Streamlit executes them
    # before the next script run, so the widget keyed "text_output" picks up
    # the change (mutating a widget's key after it has rendered would raise).
    def _clear_text():
        """CLEAR: wipe the composed text."""
        st.session_state.text_output = ""

    def _delete_char():
        """DEL: drop the last character, if any."""
        if st.session_state.text_output:
            st.session_state.text_output = st.session_state.text_output[:-1]

    def _delete_word():
        """WORD: drop the last whitespace-delimited word, keeping a trailing
        space when words remain (so typing continues naturally)."""
        if st.session_state.text_output:
            words = st.session_state.text_output.rstrip().split()
            if words:
                words.pop()
                st.session_state.text_output = " ".join(words)
                if words:
                    st.session_state.text_output += " "

    def _send_message():
        """SEND: push the composed text to the assistant and record both
        sides of the exchange in the sidebar chat history."""
        user_message = st.session_state.text_output
        if not user_message or st.session_state.assistant is None:
            return
        st.session_state.messages.append({"role": "user", "content": user_message})
        try:
            response = st.session_state.assistant.process_query(user_message)
            st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            # Surface the failure in the chat pane rather than crashing.
            error_message = f"An error occurred: {e}"
            st.session_state.messages.append({"role": "assistant", "content": f"*Error processing: {error_message}*"})
        # Clear the board for the next message.
        st.session_state.text_output = ""

    # --- Control Buttons ---
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.button("CLEAR", key="clear_btn", help="Clear all text",
                  use_container_width=True, type="primary", on_click=_clear_text)
    with col2:
        # SPEAK only reads state, so an inline check is safe here.
        if st.button("SPEAK", key="speak_btn", help="Speak the current text",
                     use_container_width=True, type="primary"):
            if st.session_state.text_output:
                st.toast(f"Speaking: {st.session_state.text_output}", icon="🔊")
        st.button("⌫ DEL", key="backspace", help="Delete last character",
                  use_container_width=True, on_click=_delete_char)
    with col3:  # Use the 3rd column for SEND
        st.button("SEND", key="send_btn", help="Send message to assistant",
                  use_container_width=True, type="primary", on_click=_send_message)
    with col4:
        st.button("⌫ WORD", key="backspace_word", help="Delete last word",
                  use_container_width=True, on_click=_delete_word)

    # --- Settings Sidebar ---
    with st.sidebar:
        st.divider()
        # --- Chat History Display in Sidebar ---
        st.subheader("Conversation")
        chat_container = st.container(height=400)  # Fixed height container
        with chat_container:
            for message in st.session_state.messages:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])
| 465 |
+
|
| 466 |
+
# Script entry point: launch the Streamlit app when run directly.
if __name__ == "__main__":
    main()
|
rag.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /Users/divyeshpatel/Desktop/archiveWork/rajvi/nlp/rag.py
|
| 2 |
+
# !pip install llama-cpp-python
|
| 3 |
+
|
| 4 |
+
# from llama.cpp import Llama
|
| 5 |
+
#
|
| 6 |
+
# llm = Llama.from_pretrained(
|
| 7 |
+
# repo_id="rdz-falcon/model",
|
| 8 |
+
# filename="unsloth.F16.gguf",
|
| 9 |
+
# )
|
| 10 |
+
|
| 11 |
+
# !pip install langchain
|
| 12 |
+
# !pip install langchain-community
|
| 13 |
+
|
| 14 |
+
# !pip install chromadb
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
import tempfile
|
| 19 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 20 |
+
from langchain.memory import ConversationBufferMemory
|
| 21 |
+
from langchain_community.document_loaders import TextLoader
|
| 22 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 23 |
+
from langchain_community.vectorstores import Chroma
|
| 24 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 25 |
+
from langchain_community.llms import HuggingFacePipeline
|
| 26 |
+
from langchain.prompts import PromptTemplate
|
| 27 |
+
from langchain_community.llms import Ollama
|
| 28 |
+
from langchain_openai import ChatOpenAI
|
| 29 |
+
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, BitsAndBytesConfig
|
| 30 |
+
|
| 31 |
+
def setup_document_retriever(document_path):
    """Build a persisted Chroma vector store over the AAC user's experiences.

    Args:
        document_path (str): Path to the plain-text file containing the
            user's personal experiences.

    Returns:
        Chroma: Vector store of embedded document chunks, persisted under
        the system temp directory.
    """
    # Load the raw experience document.
    raw_documents = TextLoader(document_path).load()

    # Chunk the text with overlap so context is preserved across boundaries.
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        separators=["\n\n", "\n", " ", ""],
    )
    doc_chunks = chunker.split_documents(raw_documents)

    # Sentence-transformer embeddings; prefer the GPU when one is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': device},
    )

    # Keep the index in a stable location under the temp directory so it
    # can be reused across process restarts.
    db_directory = os.path.join(tempfile.gettempdir(), "chroma_db")

    store = Chroma.from_documents(
        documents=doc_chunks,
        embedding=embedder,
        persist_directory=db_directory,
    )

    # Flush the index to disk.
    store.persist()

    return store
|
| 64 |
+
|
| 65 |
+
def load_emotion_classifier(api_base_url="http://127.0.0.1:1234/v1"):
    """
    This function configures and returns a LangChain LLM client
    to interact with an OpenAI-compatible API endpoint (like LM Studio).

    Args:
        api_base_url (str): The base URL of the OpenAI-compatible API endpoint.

    Returns:
        ChatOpenAI: A LangChain ChatOpenAI instance configured for the API.
    """
    print(f"=== CONFIGURING LLM CLIENT FOR API: {api_base_url} ===")

    # The local server ignores the API key, but LangChain insists on one.
    client = ChatOpenAI(
        openai_api_base=api_base_url,
        openai_api_key="dummy-key",  # Required by LangChain, but not used by LM Studio
        temperature=0.7,
        max_tokens=128,
    )
    return client
|
| 86 |
+
|
| 87 |
+
# --- The following code was commented out or unreachable in the original notebook ---
|
| 88 |
+
# Example code (replace with appropriate code for your model):
|
| 89 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 90 |
+
# model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 91 |
+
# emotion_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
| 92 |
+
|
| 93 |
+
# input_emotion = "excited"
|
| 94 |
+
# input_situation = text # 'text' variable was not defined here in the original notebook
|
| 95 |
+
|
| 96 |
+
# # Format the user message content
|
| 97 |
+
# user_content = f"Emotion: {input_emotion}\nSituation: {input_situation}"
|
| 98 |
+
# # Create the messages list in the standard OpenAI/chat format
|
| 99 |
+
# messages = [
|
| 100 |
+
# # Note: llama-cpp might not explicitly use a system prompt unless provided here
|
| 101 |
+
# # or baked into the chat_format handler. You might need to add:
|
| 102 |
+
# # {"role": "system", "content": "You are an empathetic assistant."},
|
| 103 |
+
# {"role": "user", "content": user_content},
|
| 104 |
+
# ]
|
| 105 |
+
|
| 106 |
+
# # --- 3. Generate the response using create_chat_completion -- This method doesn't exist on ChatOpenAI, use invoke instead ---
|
| 107 |
+
# print("Generating response...")
|
| 108 |
+
# try:
|
| 109 |
+
# response = llm.create_chat_completion( # This should be llm.invoke(messages)
|
| 110 |
+
# messages=messages,
|
| 111 |
+
# max_tokens=128, # Max length of the generated response (adjust as needed)
|
| 112 |
+
# temperature=0.7, # Controls randomness (adjust)
|
| 113 |
+
# # top_p=0.9, # Optional: Nucleus sampling
|
| 114 |
+
# # top_k=40, # Optional: Top-k sampling
|
| 115 |
+
# stop=["<|eot_id|>"], # Crucial: Stop generation when the model outputs the end-of-turn token
|
| 116 |
+
# stream=False, # Set to True to get token-by-token output (like TextStreamer)
|
| 117 |
+
# )
|
| 118 |
+
|
| 119 |
+
# # --- 4. Extract and print the response -- Access response.content with invoke ---
|
| 120 |
+
# if response and 'choices' in response and len(response['choices']) > 0:
|
| 121 |
+
# assistant_message = response['choices'][0]['message']['content']
|
| 122 |
+
# print("\nAssistant Response:")
|
| 123 |
+
# print(assistant_message.strip())
|
| 124 |
+
# print("returning:", assistant_message.strip())
|
| 125 |
+
# return assistant_message.strip()
|
| 126 |
+
# else:
|
| 127 |
+
# print("\nNo response generated or unexpected format.")
|
| 128 |
+
# print("Full response:", response)
|
| 129 |
+
|
| 130 |
+
# return ""
|
| 131 |
+
|
| 132 |
+
# except Exception as e:
|
| 133 |
+
# print(f"\nAn error occurred during generation: {e}")
|
| 134 |
+
# return ""
|
| 135 |
+
# --- End of commented out/unreachable code ---
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def load_generation_model():
    """Load the specified Ollama model using LangChain.

    Raises:
        Exception: Re-raised if the Ollama client cannot be configured
            (e.g. the server is down or the model is missing).
    """
    print("=== CONFIGURING OLLAMA GENERATION MODEL ===")
    model_name = "llama3.2"  # Your desired Ollama model

    try:
        # Instantiate the Ollama LLM client for the chosen model.
        llm = Ollama(
            model=model_name,
            # temperature=0.1
        )
    except Exception as e:
        print(f"Error configuring Ollama model: {e}")
        print("Please ensure the Ollama server is running and the model is available.")
        raise

    print(f"Ollama model '{model_name}' configured.")
    return llm
|
| 156 |
+
|
| 157 |
+
def create_prompt_templates():
    """Build the PromptTemplate that frames every answer as the AAC user.

    The template carries four slots: the partner's ``question``, the
    ``emotion_analysis`` guidance, the retrieved personal ``context`` and the
    running ``chat_history``.

    Returns:
        PromptTemplate: the configured template (also echoed to stdout for
        debugging).
    """

    prompt_text = """
<|system|>
You are an AAC (Augmentative and Alternative Communication) user (Elliot) engaging in a conversation. Your responses must reflect factual details provided in your personal context, be empathetic as guided by the emotion analysis, and align naturally with your previous chat history. You will respond directly as the AAC user, speaking in the first person (using "I", "my", "me").

**Instructions:**
1. Understand the question asked by the conversation partner.
2. Use the provided "Context" to include accurate personal details about your life (Elliot).
3. Reflect the empathetic tone described in the "Empathetic Response Guidance".
4. Ensure your response fits logically within the "Chat History".
5. Keep your response concise, empathetic, and natural.
6. Ignore the empathetic tone described in the "Empathetic Response Guidance" if it is not related to the conversation.

**Context:**
{context}

**Chat History:**
{chat_history}

**Empathetic Response Guidance:**
{emotion_analysis}</s>
<|user|>
The conversation partner asked: "{question}"

Please generate your response as the AAC user, following the instructions above.</s>
<|assistant|>

""".strip()

    prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["question", "emotion_analysis", "context", "chat_history"],
    )
    print("\n Prompt:", prompt)
    return prompt
|
| 194 |
+
|
| 195 |
+
class AACAssistant:
    """Conversational assistant that answers as an AAC user (Elliot).

    Wires together:
      * a vector store retriever over the user's personal-experience document,
      * an external emotion-analysis LLM (OpenAI-compatible chat API client),
      * a local Ollama generation model driven through a
        ConversationalRetrievalChain with buffered chat history.
    """

    def __init__(self, document_path):
        """Initialize retriever, models, prompt, memory and the RAG chain.

        Args:
            document_path (str): path to the AAC user's personal-experience
                document used to build the retriever.
        """
        print("Initializing AAC Assistant...")
        print("Loading document retriever...")
        self.vectorstore = setup_document_retriever(document_path)
        print("Configuring emotion LLM client...")
        # Client for the external emotion-analysis API (invoked with
        # chat-style message dicts; response exposes `.content`).
        self.emotion_llm = load_emotion_classifier()
        print("Loading generation model...")
        self.llm = load_generation_model()  # Ollama model via LangChain
        print("Creating prompt templates...")
        self.prompt = create_prompt_templates()
        print("Setting up conversation memory...")

        # Chat-history buffer. input_key/output_key must match the keys the
        # chain is invoked with / returns, so memory records the right fields
        # even though the chain also returns source documents.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
            input_key="question"
        )

        # RAG chain using the main generation model; retrieves top-3 chunks
        # and combines them with the custom prompt template.
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(search_kwargs={'k': 3}),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.prompt},
            return_source_documents=True,
            verbose=True
        )

        print("AAC Assistant initialized and ready!")

    def get_emotion_analysis(self, situation, input_emotion="excited"):
        """Get an empathy analysis of *situation* from the emotion LLM API.

        Args:
            situation (str): the text (usually the partner's question) to
                analyze.
            input_emotion (str): seed emotion label included in the prompt.
                Defaults to "excited", which was previously hard-coded;
                callers may now supply a dynamically detected emotion.

        Returns:
            str: the stripped analysis text, or a fallback message describing
            the API error if the call fails (never raises).
        """
        user_content = f"Emotion: {input_emotion}\nSituation: {situation}\nGenerate a brief analysis of the user's likely feeling based on the situation."

        messages = [
            {"role": "user", "content": user_content},
        ]

        print(f"Sending to emotion API: {messages}")
        try:
            # ChatOpenAI-style client: invoke() returns a message object
            # whose text lives in `.content`.
            response = self.emotion_llm.invoke(messages)
            analysis = response.content.strip()
            print(f"Received from emotion API: {analysis}")
            return analysis
        except Exception as e:
            print(f"\nAn error occurred during emotion analysis API call: {e}")
            # Degrade gracefully so process_query can still run the RAG chain.
            return f"Could not determine emotion (API error: {e})"

    def process_query(self, user_query):
        """Process a query from the conversation partner to the AAC user.

        Args:
            user_query (str): question asked by the conversation partner.

        Returns:
            str: generated response for the AAC user to communicate.
        """
        # Step 1: emotion analysis feeds the prompt's empathy guidance slot.
        print(f"Getting emotion analysis for query: '{user_query}'")
        emotion_analysis = self.get_emotion_analysis(user_query)
        print(f"Emotion Analysis Result: {emotion_analysis}")

        # Step 2: run the RAG + generation chain. invoke() (not __call__)
        # with a dict matching the chain's expected input keys.
        print("Running main RAG chain...")
        response = self.chain.invoke(
            {"question": user_query, "emotion_analysis": emotion_analysis}
        )

        return response["answer"]
|
| 283 |
+
|
| 284 |
+
# def run_demo():
|
| 285 |
+
# # Sample personal experiences document path - replace with your actual file
|
| 286 |
+
# document_path = "aac_user_experiences.txt"
|
| 287 |
+
|
| 288 |
+
# # Create a dummy document if it doesn't exist for demonstration
|
| 289 |
+
# # if not os.path.exists(document_path):
|
| 290 |
+
# # with open(document_path, "w") as f:
|
| 291 |
+
# # f.write("""
|
| 292 |
+
# # I grew up in Seattle and love the rain.
|
| 293 |
+
# # My favorite hobby is playing chess, which I've been doing since I was 7 years old.
|
| 294 |
+
# # I have a dog named Max who is a golden retriever.
|
| 295 |
+
# # I went to college at University of Washington and studied computer science.
|
| 296 |
+
# # I enjoy watching sci-fi movies and Star Trek is my favorite series.
|
| 297 |
+
# # I've traveled to Japan twice and love Japanese cuisine.
|
| 298 |
+
# # Music helps me relax, especially classical piano pieces.
|
| 299 |
+
# # I volunteer at the local animal shelter once a month.
|
| 300 |
+
# # """)
|
| 301 |
+
|
| 302 |
+
# # Initialize the assistant
|
| 303 |
+
# assistant = AACAssistant(document_path)
|
| 304 |
+
|
| 305 |
+
# # Interactive demo
|
| 306 |
+
# print("\n===== AAC Communication Assistant Demo =====")
|
| 307 |
+
# print("(Type 'exit' to end the demo)")
|
| 308 |
+
|
| 309 |
+
# while True:
|
| 310 |
+
# try:
|
| 311 |
+
# user_input = input("\nConversation partner says: ")
|
| 312 |
+
# if user_input.lower() == 'exit':
|
| 313 |
+
# break
|
| 314 |
+
|
| 315 |
+
# response = assistant.process_query(user_input)
|
| 316 |
+
# print(f"\nAAC user communicates: {response}")
|
| 317 |
+
# except EOFError: # Handle case where input stream ends unexpectedly
|
| 318 |
+
# print("\nInput stream closed. Exiting demo.")
|
| 319 |
+
# break
|
| 320 |
+
# except KeyboardInterrupt: # Handle Ctrl+C
|
| 321 |
+
# print("\nDemo interrupted by user. Exiting.")
|
| 322 |
+
# break
|
| 323 |
+
# except Exception as e:
|
| 324 |
+
# print(f"\nAn unexpected error occurred: {e}")
|
| 325 |
+
# # Optionally add more specific error handling or logging
|
| 326 |
+
# # Consider whether to break or continue the loop on error
|
| 327 |
+
# break # Exit on error for safety
|
| 328 |
+
|
| 329 |
+
# try:
|
| 330 |
+
# from importlib.metadata import PackageNotFoundError
|
| 331 |
+
# except ImportError:
|
| 332 |
+
# # Define a fallback for older Python versions
|
| 333 |
+
# class PackageNotFoundError(Exception):
|
| 334 |
+
# pass
|
| 335 |
+
|
| 336 |
+
# # Cell 13: Main Execution Block
|
| 337 |
+
# if __name__ == "__main__":
|
| 338 |
+
# run_demo()
|
| 339 |
+
|
| 340 |
+
# # !pip install bitsandbytes -q || echo "bitsandbytes installation failed, will use fp16 precision instead"
|
| 341 |
+
# # pip install -U bitsandbytes
|