BeRU Deployer committed on
Commit
dec533d
Β·
1 Parent(s): 8357835

Deploy BeRU Streamlit RAG System - Add app, models logic, configs, and optimizations for HF Spaces

Browse files
Files changed (10) hide show
  1. .hfignore +7 -0
  2. .streamlit/config.toml +21 -0
  3. Dockerfile +49 -0
  4. README.md +47 -6
  5. app.py +282 -0
  6. down.py +898 -0
  7. frontend.html +1075 -0
  8. requirements.txt +19 -0
  9. spaces_app.py +229 -0
  10. vlm2rag2.py +1354 -0
.hfignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # ignore local model and index directories when pushing to Hugging Face
2
+ models/
3
+ VLM2Vec-V2rag3/
4
+ faiss_index/
5
+ venv*
6
+ *.log
7
+ rag.log
.streamlit/config.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [global]
2
+ # Page customization
3
+ analyticsEnabled = false
4
+ logLevel = "info"
5
+
6
+ [client]
7
+ # Faster loading
8
+ showErrorDetails = true
9
+ toolbarMode = "minimal"
10
+
11
+ [server]
12
+ # HF Spaces optimizations
13
+ port = 7860
14
+ headless = true
15
+ runOnSave = false
16
+ maxUploadSize = 200
17
+ enableCORS = false
18
+ enableXsrfProtection = true
19
+
20
+ # Memory management
21
+ maxCachedMessageSize = 2
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies (for PDF processing, OCR, etc.)
RUN apt-get update && apt-get install -y \
    build-essential \
    git \
    libpoppler-cpp-dev \
    poppler-utils \
    tesseract-ocr \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached between code changes
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY down.py .
COPY frontend.html .
COPY vlm2rag2.py .
# NOTE(review): check_user.py is not in this commit's file list — confirm it
# exists in the build context or this COPY will fail the build.
COPY check_user.py .
# include the Streamlit demo as an alternative entrypoint
COPY app.py .

# Create necessary directories
RUN mkdir -p /app/.cache /app/models /app/VLM2Vec-V2rag3

# Expose HF Spaces default port
EXPOSE 7860

# By default the image runs the FastAPI server; to start the Streamlit UI
# for local testing you can override the command:
#   docker run -p 7860:7860 <image> streamlit run app.py --server.port=7860
# (the continuation line above was previously missing its '#', which made
# the Dockerfile invalid)

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV HF_HUB_DISABLE_SYMLINKS_WARNING=1

# default paths used by down.py (can be overridden at runtime)
ENV MODEL_DIR=/app/models
ENV LLM_MODEL_PATH=/app/models/Mistral-7B-Instruct-v0.3
ENV EMBED_MODEL_PATH=/app/models/VLM2Vec-Qwen2VL-2B
ENV FAISS_INDEX_PATH=/app/VLM2Vec-V2rag3

# Run FastAPI app on port 7860
CMD ["python", "down.py", "--port", "7860"]
README.md CHANGED
@@ -1,13 +1,54 @@
1
  ---
2
- title: BeRu
3
- emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: Refinary assistance
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: BeRU Chat - RAG Assistant
3
+ emoji: πŸ€–
4
  colorFrom: indigo
5
  colorTo: yellow
6
+ sdk: streamlit
 
7
  app_file: app.py
8
  pinned: false
9
+ short_description: 100% Offline RAG System with Mistral 7B and VLM2Vec
10
  ---
11
 
12
+ # πŸ€– BeRU Chat - RAG Assistant
13
+
14
+ A powerful **100% offline Retrieval-Augmented Generation (RAG) system** combining Mistral 7B LLM with VLM2Vec embeddings for intelligent document search and conversation.
15
+
16
+ ## ✨ Features
17
+
18
+ - πŸ”’ **100% Offline Operation** - No internet required after startup
19
+ - 🧠 **Advanced RAG Architecture**
20
+ - Hybrid retrieval (Vector + BM25 keyword search)
21
+ - Ensemble retriever combining multiple strategies
22
+ - Re-ranking with FlashRank for relevance
23
+ - Multi-turn conversation with history awareness
24
+ - ⚑ **Optimized Performance**
25
+ - 4-bit quantization with BitsAndBytes
26
+ - Flash Attention 2 support
27
+ - FAISS vector indexing
28
+ - πŸ“š **Source Citations** - Every answer cites original sources
29
+
30
+ ## 🎯 Models Used
31
+
32
+ | Component | Model | Details |
33
+ |-----------|-------|---------|
34
+ | **LLM** | Mistral-7B-Instruct-v0.3 | 7B parameters |
35
+ | **Embedding** | VLM2Vec-Qwen2VL-2B | 2B parameters |
36
+ | **Vector Store** | FAISS | Meta's similarity search |
37
+
38
+ ## πŸš€ Getting Started
39
+
40
+ 1. **Wait for Models** - First load takes 5-8 minutes (models download from HF Hub)
41
+ 2. **Upload Documents** - Add PDFs or text files for RAG
42
+ 3. **Ask Questions** - Chat with context-aware answers
43
+ 4. **Get Sources** - Each answer includes citations
44
+
45
+ ## πŸ’» System Requirements
46
+
47
+ - **GPU**: A10G (24GB VRAM) recommended
48
+ - **RAM**: 16GB minimum
49
+ - **Cold Start**: ~5-8 minutes (first time)
50
+ - **Runtime**: Streamlit app on port 7860
51
+
52
+ ## πŸ“– Documentation
53
+
54
+ For more information, visit the [GitHub repository](https://github.com/AnwinJosy/BeRU)
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ import pickle
5
+ import faiss
6
+ import numpy as np
7
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer
8
+ from typing import List, Dict
9
+ import time
10
+
11
+ # ========================================
12
+ # 🎨 STREAMLIT PAGE CONFIG
13
+ # ========================================
14
+ st.set_page_config(
15
+ page_title="BeRU Chat - RAG Assistant",
16
+ page_icon="πŸ€–",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded"
19
+ )
20
+
21
+ # ========================================
22
+ # 🎯 CACHING FOR MODEL LOADING
23
+ # ========================================
24
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model, processor and tokenizer (cached).

    Returns:
        tuple: ``(model, processor, tokenizer, device)`` where *device* is
        ``"cuda"`` when a GPU is available, otherwise ``"cpu"``.
    """
    st.write("⏳ Loading embedding model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only pays off on GPU; CPU inference stays in fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    repo_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"
    model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        torch_dtype=dtype,
    ).to(device)
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    model.eval()
    st.success("βœ… Embedding model loaded!")
    return model, processor, tokenizer, device
49
+
50
@st.cache_resource
def load_llm_model():
    """Load Mistral-7B-Instruct in 4-bit quantization (cached across reruns).

    Returns:
        tuple: ``(model, tokenizer, device)``.
    """
    st.write("⏳ Loading language model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # NF4 with double quantization keeps the 7B model within a small VRAM budget.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quant_cfg,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    st.success("βœ… Language model loaded!")
    return model, tokenizer, device
78
+
79
@st.cache_resource
def load_faiss_index(index_dir: str = None):
    """Load the FAISS text index from disk, if it has been built.

    Args:
        index_dir: Directory containing ``text_index.faiss``.  Defaults to the
            ``FAISS_INDEX_PATH`` environment variable, falling back to the
            original hard-coded ``VLM2Vec-V2rag2`` directory.
            NOTE(review): the Dockerfile and down.py both use
            ``VLM2Vec-V2rag3`` — confirm which directory name is correct;
            the env-var override lets the container point at the right one.

    Returns:
        faiss.Index | None: the loaded index, or ``None`` when the index
        file does not exist (a warning is shown in the UI).
    """
    if index_dir is None:
        index_dir = os.environ.get("FAISS_INDEX_PATH", "VLM2Vec-V2rag2")
    index_path = os.path.join(index_dir, "text_index.faiss")

    if os.path.exists(index_path):
        st.write("⏳ Loading FAISS index...")
        index = faiss.read_index(index_path)
        st.success("βœ… FAISS index loaded!")
        return index

    st.warning("⚠️ FAISS index not found. Please build the index first.")
    return None
90
+
91
+ # ========================================
92
+ # πŸ’¬ EMBEDDING & RETRIEVAL FUNCTIONS
93
+ # ========================================
94
def get_embeddings(texts: List[str], model, processor, tokenizer, device) -> np.ndarray:
    """Embed each text as the mean of its last-hidden-layer token vectors.

    Args:
        texts: Strings to embed.
        model: VLM2Vec model; called with ``output_hidden_states=True``.
        processor: Kept for interface compatibility; not used here.
        tokenizer: Tokenizer producing tensors moved onto *device*.
        device: Torch device string.

    Returns:
        np.ndarray of shape ``(len(texts), hidden_dim)``.
    """
    vectors = []
    for chunk in texts:
        encoded = tokenizer(
            chunk, return_tensors="pt", max_length=512, truncation=True
        ).to(device)

        with torch.no_grad():
            result = model(**encoded, output_hidden_states=True)
            # Mean-pool over the sequence dimension of the final layer.
            pooled = result.hidden_states[-1].mean(dim=1).cpu().numpy()

        vectors.append(pooled.flatten())

    return np.array(vectors)
108
+
109
def retrieve_documents(query: str, model, processor, tokenizer, device, faiss_index, k: int = 5) -> List[Dict]:
    """Retrieve the top-*k* nearest documents for *query* via FAISS.

    Args:
        query: User question to embed and search with.
        model/processor/tokenizer/device: Embedding components forwarded
            to :func:`get_embeddings`.
        faiss_index: A loaded ``faiss.Index`` or ``None``.
        k: Number of neighbours to request.

    Returns:
        List of ``{"index": int, "distance": float}`` dicts; empty when no
        index is loaded.
    """
    if faiss_index is None:
        return []

    # Embed the query with the same model used for the corpus.
    query_embedding = get_embeddings([query], model, processor, tokenizer, device)

    # Search the FAISS index.
    distances, indices = faiss_index.search(query_embedding, k)

    # Pair each hit with its OWN distance.  The original code looked the
    # distance up via list(...).index(idx), which returns the distance of the
    # first occurrence — wrong for duplicate ids and O(k^2) per query.
    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx >= 0:  # FAISS pads with -1 when fewer than k vectors exist
            results.append({
                "index": idx,
                "distance": float(dist)
            })

    return results
130
+
131
def generate_response(query: str, context: str, model, tokenizer, device,
                      temperature: float = 0.7, max_new_tokens: int = 512) -> str:
    """Generate an answer with Mistral given retrieved *context*.

    Args:
        query: The user question.
        context: Concatenated retrieved-document text injected into the prompt.
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer (encode + decode).
        device: Torch device string the inputs are moved to.
        temperature: Sampling temperature.  Previously hard-coded to 0.7, which
            silently ignored the UI temperature slider; now overridable while
            keeping 0.7 as the backward-compatible default.
        max_new_tokens: Generation budget (default preserves old behavior).

    Returns:
        The model's answer with the prompt portion stripped.
    """
    # Mistral-Instruct chat format: single [INST] ... [/INST] turn.
    prompt = f"""[INST] You are a helpful assistant answering questions about technical documentation.

Context:
{context}

Question: {query} [/INST]"""

    inputs = tokenizer(prompt, return_tensors="pt", max_length=2048, truncation=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Everything after [/INST] is the generated answer; fall back to the full
    # decode if the marker was dropped during decoding.
    return response.split("[/INST]")[1].strip() if "[/INST]" in response else response
154
+
155
+ # ========================================
156
+ # 🎨 STREAMLIT UI
157
+ # ========================================
158
+ st.title("πŸ€– BeRU Chat Assistant")
159
+ st.markdown("*100% Offline RAG System with Mistral 7B & VLM2Vec*")
160
+
161
+ # Sidebar Configuration
162
+ with st.sidebar:
163
+ st.header("βš™οΈ Configuration")
164
+
165
+ device_info = "🟒 GPU" if torch.cuda.is_available() else "πŸ”΄ CPU"
166
+ st.metric("Device", device_info)
167
+
168
+ num_results = st.slider("Retrieve top K documents", 1, 10, 5)
169
+ temperature = st.slider("Response Temperature", 0.1, 1.0, 0.7)
170
+
171
+ st.divider()
172
+ st.markdown("### πŸ“Š Project Info")
173
+ st.markdown("""
174
+ - **Model**: Mistral 7B Instruct v0.3
175
+ - **Embeddings**: VLM2Vec-Qwen2VL-2B
176
+ - **Vector Store**: FAISS with 10K+ documents
177
+ - **Retrieval**: Hybrid (Dense + BM25)
178
+ """)
179
+
180
+ # Main Chat Interface
181
+ col1, col2 = st.columns([3, 1])
182
+
183
+ with col1:
184
+ st.subheader("πŸ’¬ Ask a Question")
185
+
186
+ with col2:
187
+ if st.button("πŸ”„ Clear Chat", use_container_width=True):
188
+ st.session_state.messages = []
189
+ st.rerun()
190
+
191
+ # Initialize session state
192
+ if "messages" not in st.session_state:
193
+ st.session_state.messages = []
194
+
195
+ if "models_loaded" not in st.session_state:
196
+ st.session_state.models_loaded = False
197
+
198
+ # Load models
199
+ if not st.session_state.models_loaded:
200
+ st.info("πŸ“¦ Loading models on first run... This may take 2-3 minutes.")
201
+
202
+ try:
203
+ embed_model, processor, tokenizer_embed, embed_device = load_embedding_model()
204
+ llm_model, tokenizer_llm, llm_device = load_llm_model()
205
+ faiss_idx = load_faiss_index()
206
+
207
+ st.session_state.embed_model = embed_model
208
+ st.session_state.processor = processor
209
+ st.session_state.tokenizer_embed = tokenizer_embed
210
+ st.session_state.embed_device = embed_device
211
+ st.session_state.llm_model = llm_model
212
+ st.session_state.tokenizer_llm = tokenizer_llm
213
+ st.session_state.llm_device = llm_device
214
+ st.session_state.faiss_idx = faiss_idx
215
+ st.session_state.models_loaded = True
216
+ st.success("βœ… All models loaded successfully!")
217
+
218
+ except Exception as e:
219
+ st.error(f"❌ Error loading models: {str(e)}")
220
+ st.stop()
221
+
222
+ # Chat Interface
223
+ st.markdown("---")
224
+
225
+ # Display chat history
226
+ for message in st.session_state.messages:
227
+ with st.chat_message(message["role"]):
228
+ st.markdown(message["content"])
229
+
230
+ # User input
231
+ user_input = st.chat_input("Type your question here...", key="user_input")
232
+
233
+ if user_input:
234
+ # Add user message to chat
235
+ st.session_state.messages.append({"role": "user", "content": user_input})
236
+
237
+ with st.chat_message("user"):
238
+ st.markdown(user_input)
239
+
240
+ # Generate response
241
+ with st.chat_message("assistant"):
242
+ st.write("πŸ” Retrieving relevant documents...")
243
+
244
+ # Retrieve documents
245
+ retrieved = retrieve_documents(
246
+ user_input,
247
+ st.session_state.embed_model,
248
+ st.session_state.processor,
249
+ st.session_state.tokenizer_embed,
250
+ st.session_state.embed_device,
251
+ st.session_state.faiss_idx,
252
+ k=num_results
253
+ )
254
+
255
+ context = "\n\n".join([f"Document {i+1}: Context from index {doc['index']}"
256
+ for i, doc in enumerate(retrieved)])
257
+
258
+ st.write("πŸ’­ Generating response...")
259
+
260
+ # Generate response
261
+ response = generate_response(
262
+ user_input,
263
+ context,
264
+ st.session_state.llm_model,
265
+ st.session_state.tokenizer_llm,
266
+ st.session_state.llm_device
267
+ )
268
+
269
+ st.markdown(response)
270
+
271
+ # Add to chat history
272
+ st.session_state.messages.append({"role": "assistant", "content": response})
273
+
274
+ # Footer
275
+ st.markdown("---")
276
+ st.markdown("""
277
+ <div style='text-align: center; color: gray; font-size: 12px;'>
278
+ <p>BeRU Chat Assistant | Powered by Mistral 7B + VLM2Vec | 100% Offline</p>
279
+ <p><a href='https://github.com/AnwinJosy/BeRU'>GitHub</a> |
280
+ <a href='https://huggingface.co/AnwinJosy'>Hugging Face</a></p>
281
+ </div>
282
+ """, unsafe_allow_html=True)
down.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ import asyncio
5
+ import re
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional, Any
8
+ from contextlib import asynccontextmanager
9
+ from logging.handlers import RotatingFileHandler
10
+
11
+ # --- LANGCHAIN IMPORTS ---
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain.chains import create_history_aware_retriever
14
+ from langchain.chains.retrieval import create_retrieval_chain
15
+ from langchain.chains.combine_documents import create_stuff_documents_chain
16
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langchain_community.llms import HuggingFacePipeline
18
+ from langchain_core.embeddings import Embeddings
19
+ from langchain_core.messages import HumanMessage, AIMessage
20
+ from langchain_community.retrievers import BM25Retriever
21
+ from langchain.retrievers import EnsembleRetriever
22
+ from langchain.retrievers.multi_query import MultiQueryRetriever
23
+ from langchain_core.runnables import RunnablePassthrough
24
+ from langchain_core.output_parsers import StrOutputParser
25
+ from operator import itemgetter
26
+
27
+ # --- RERANKING IMPORTS ---
28
+ # Ensure you have installed flashrank: pip install flashrank
29
+ from langchain.retrievers import ContextualCompressionRetriever
30
+ from langchain_community.document_compressors import FlashrankRerank
31
+
32
+ # --- TRANSFORMERS IMPORTS ---
33
+ from transformers import (
34
+ AutoTokenizer,
35
+ AutoModelForCausalLM,
36
+ AutoModel,
37
+ pipeline,
38
+ BitsAndBytesConfig
39
+ )
40
+
41
+ # --- FASTAPI IMPORTS ---
42
+ from fastapi import FastAPI
43
+ from fastapi.responses import HTMLResponse, JSONResponse
44
+ from fastapi.middleware.cors import CORSMiddleware
45
+ from pydantic import BaseModel, Field, field_validator
46
+ import uvicorn
47
+ import numpy as np
48
+
49
+ # -------------------------------------------------------------------------
50
+ # 1. Pydantic Patch (Crucial for offline serialization)
51
+ # -------------------------------------------------------------------------
52
def patch_pydantic_for_pickle():
    """Patch pydantic v1 ``BaseModel.__setstate__`` for pickle compatibility.

    Documents pickled under an older pydantic may lack ``__fields_set__`` /
    ``__private_attribute_values__`` in their state dict; this patch
    reconstructs them before delegating to the original implementation and,
    if that still fails, force-assigns the state attribute by attribute.

    The original code duplicated the whole patch in two near-identical
    import-fallback branches (one with a bare ``except:``); this version
    resolves the import first, then applies a single patch.  Failure to
    patch is reported and swallowed — the app can still run, it just may be
    unable to load old pickles.
    """
    try:
        try:
            from pydantic.v1.main import BaseModel as V1BaseModel
        except ImportError:
            # Older layouts expose BaseModel directly under pydantic.v1.
            from pydantic.v1 import BaseModel as V1BaseModel

        original_setstate = V1BaseModel.__setstate__

        def patched_setstate(self, state):
            # Reconstruct fields missing from pre-migration pickles.
            if '__fields_set__' not in state:
                state['__fields_set__'] = set(state.get('__dict__', {}).keys())
            if '__private_attribute_values__' not in state:
                state['__private_attribute_values__'] = {}
            try:
                original_setstate(self, state)
            except Exception:
                # Last resort: bypass pydantic entirely and set raw state.
                object.__setattr__(self, '__dict__', state.get('__dict__', {}))
                object.__setattr__(self, '__fields_set__', state.get('__fields_set__', set()))
                object.__setattr__(self, '__private_attribute_values__', state.get('__private_attribute_values__', {}))

        V1BaseModel.__setstate__ = patched_setstate
        print("βœ… Pydantic v1 patched for pickle compatibility")

    except Exception as e:
        print(f"⚠️ Could not patch Pydantic: {e}")
94
+
95
+ patch_pydantic_for_pickle()
96
+
97
+ # -------------------------------------------------------------------------
98
+ # 2. Configuration & Paths (workspace-agnostic)
99
+ # -------------------------------------------------------------------------
100
+ # environment variables allow overrides when running in containers / Spaces
101
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
102
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
103
+ os.environ["HF_DATASETS_OFFLINE"] = "1"
104
+ os.environ["HF_HUB_OFFLINE"] = "1"
105
+
106
+ # base directory for application files inside a container
107
+ ROOT_DIR = Path(os.environ.get("APP_ROOT", "/app")).resolve()
108
+
109
+ # model and index locations can be provided via env; defaults point into /app
110
+ MODEL_DIR = Path(os.environ.get("MODEL_DIR", ROOT_DIR / "models"))
111
+ LLM_MODEL_PATH = Path(os.environ.get("LLM_MODEL_PATH", MODEL_DIR / "Mistral-7B-Instruct-v0.3"))
112
+ EMBED_MODEL_PATH = Path(os.environ.get("EMBED_MODEL_PATH", MODEL_DIR / "VLM2Vec-Qwen2VL-2B"))
113
+ FAISS_INDEX_PATH = Path(os.environ.get("FAISS_INDEX_PATH", ROOT_DIR / "VLM2Vec-V2rag3"))
114
+
115
+ # Increased timeout for reranking operations
116
+ GENERATION_TIMEOUT = 240
117
+ LLM_MODEL = str(LLM_MODEL_PATH)
118
+ EMBED_MODEL = str(EMBED_MODEL_PATH)
119
+
120
+ # Logging Setup
121
+ logger = logging.getLogger("rag_system")
122
+ handler = RotatingFileHandler("rag.log", maxBytes=10 * 1024 * 1024, backupCount=5)
123
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
124
+ handler.setFormatter(formatter)
125
+ logger.addHandler(handler)
126
+ logger.setLevel(logging.INFO)
127
+
128
+ # Global Variables
129
+ vectorstore = None
130
+ llm_pipeline = None
131
+ qa_chain = None
132
+ answer_cache: Dict[str, Dict] = {}
133
+ conversations: Dict[str, List[Dict]] = {}
134
+
135
+ # -------------------------------------------------------------------------
136
+ # 3. VLM2Vec Embedding Class (Preserved)
137
+ # -------------------------------------------------------------------------
138
class VLM2VecEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by a local VLM2Vec model.

    Texts are embedded by mean-pooling the final hidden layer, weighted by
    the attention mask.  All files are loaded with ``local_files_only=True``
    so the class works fully offline.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """Load tokenizer + model from *model_path* and probe the embedding dim.

        Args:
            model_path: Local directory containing the VLM2Vec checkpoint.
            device: ``"cuda"`` or ``"cpu"``; selects dtype and device_map.
        """
        print(f"πŸ”— Loading VLM2Vec model from: {model_path}")

        self.device = device
        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        )

        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        device_map = "auto" if device == "cuda" else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        # NOTE(review): `dtype=` is the transformers >= 4.56 spelling; older
        # versions expect `torch_dtype=` — confirm against the pinned version.
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            local_files_only=True,
        )

        self.model.eval()

        # device_map="auto" may place the model elsewhere than requested;
        # read the real device from the parameters.  Only next() on an empty
        # iterator can fail here, so catch StopIteration (was a bare except:).
        try:
            self.model_device = next(self.model.parameters()).device
        except StopIteration:
            self.model_device = torch.device("cuda" if device == "cuda" else "cpu")

        # One probe forward pass to discover the embedding dimensionality.
        with torch.no_grad():
            test_input = self.tokenizer("test", return_tensors="pt", add_special_tokens=True)
            test_input = {k: v.to(self.model_device) for k, v in test_input.items()}
            out = self.model(**test_input, output_hidden_states=True)
            self.embedding_dim = out.hidden_states[-1].shape[-1]

        print(f"βœ… VLM2Vec loaded on {self.model_device} | dim={self.embedding_dim}\n")

    def _normalize_text(self, text: str) -> str:
        """Collapse whitespace and strip 'Page N' artifacts from extracted text."""
        text = re.sub(r'\s+', ' ', text or "")
        text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)
        return text.strip()

    def _ensure_non_empty(self, text: str) -> str:
        """Return normalized text, or a placeholder so tokenization never sees ''."""
        t = self._normalize_text(text)
        return t if t else "[EMPTY]"

    def _embed_single(self, text: str) -> List[float]:
        """Embed one string; on any failure, log and return a zero vector."""
        try:
            with torch.no_grad():
                clean_text = self._ensure_non_empty(text)

                inputs = self.tokenizer(
                    clean_text,
                    return_tensors="pt",
                    add_special_tokens=True,
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {k: v.to(self.model_device) for k, v in inputs.items()}

                outputs = self.model(**inputs, output_hidden_states=True)

                if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
                    # Attention-mask-weighted mean pooling over tokens.
                    hidden_states = outputs.hidden_states[-1]
                    attention_mask = inputs["attention_mask"].unsqueeze(-1).float()

                    weighted = hidden_states * attention_mask
                    sum_embeddings = weighted.sum(dim=1)
                    sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                    embedding = (sum_embeddings / sum_mask).squeeze(0)
                else:
                    # Fallback for models that don't return hidden states.
                    embedding = outputs.logits.mean(dim=1).squeeze(0)

                return embedding.cpu().numpy().tolist()

        except Exception as e:
            logger.error(f"VLM2Vec embedding error: {e}")
            # Zero vector of the probed dimension (1024 if the probe never ran).
            return [0.0] * getattr(self, "embedding_dim", 1024)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents (one forward pass per text)."""
        return [self._embed_single(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query with the same pooling as documents."""
        return self._embed_single(text)
228
+
229
+ # -------------------------------------------------------------------------
230
+ # 4. Prompt Templates (CLEANER & STRICTER)
231
+ # -------------------------------------------------------------------------
232
+ PROMPT_TEMPLATES = {
233
+ "Short and Concise": """<s>[INST] Answer the question based ONLY on the following context. Keep the answer under 3 sentences.
234
+
235
+ Context:
236
+ {context}
237
+
238
+ Question:
239
+ {input} [/INST]""",
240
+
241
+ "Detailed": """<s>[INST] You are a helpful assistant. Answer the question using ONLY the following context. Provide a detailed summary (4-5 sentences).
242
+
243
+ Context:
244
+ {context}
245
+
246
+ Question:
247
+ {input} [/INST]""",
248
+
249
+ "Step-by-Step": """<s>[INST] Based on the context below, provide a step-by-step procedure to answer the question.
250
+
251
+ Context:
252
+ {context}
253
+
254
+ Question:
255
+ {input} [/INST]""",
256
+ }
257
+
258
def structure_answer(answer: str, style: str) -> str:
    """Post-process raw LLM output.

    Strips known prompt artifacts, truncates at markers that indicate the
    model hallucinated a new conversational turn, and enforces the requested
    answer style.

    Args:
        answer: Raw generated text.
        style: One of the PROMPT_TEMPLATES keys; only "Short and Concise"
            changes formatting (keeps the first two sentences).

    Returns:
        The cleaned answer (possibly empty).
    """
    # 1. Remove known prompt/format artifacts anywhere in the text.
    #    (str.replace is already a no-op when absent, so no membership check.)
    artifacts = [
        "Enough thinking",
        "Note:",
        "System:",
        "User:",
        "[/INST]",
        "Here is the answer:",
        "Answer:"
    ]
    for artifact in artifacts:
        answer = answer.replace(artifact, "")

    # 2. Truncate at markers that signal the model started a new turn.
    stop_markers = ["Human:", "Question:", "User input:", "Context:"]
    for marker in stop_markers:
        if marker in answer:
            answer = answer.split(marker)[0]

    clean_answer = answer.strip()

    # 3. Style enforcement: keep only the first two sentences.  Empty
    #    fragments are dropped and a lone "." is no longer produced for
    #    empty answers (the original appended "." unconditionally).
    if style == "Short and Concise":
        sentences = [s.strip() for s in clean_answer.split('.') if s.strip()]
        clean_answer = ". ".join(sentences[:2]) + "." if sentences else ""

    return clean_answer
290
+ # -------------------------------------------------------------------------
291
+ # 5. Load System
292
+ # -------------------------------------------------------------------------
293
def load_system():
    """Validate on-disk artifacts, then bring up the full offline RAG stack.

    Populates the module globals ``vectorstore``, ``llm_pipeline`` and
    ``qa_chain`` via the private loader helpers.

    Raises:
        FileNotFoundError: when the LLM, the embedding model, or the FAISS
            index directory is missing on disk.
    """
    global vectorstore, llm_pipeline, qa_chain

    # Fail fast with a precise message for each missing artifact.
    if not os.path.exists(LLM_MODEL_PATH):
        raise FileNotFoundError(f"LLM model not found at: {LLM_MODEL_PATH}")
    if not os.path.exists(EMBED_MODEL_PATH):
        raise FileNotFoundError(f"Embedding model not found at: {EMBED_MODEL_PATH}")
    if not os.path.exists(FAISS_INDEX_PATH):
        raise FileNotFoundError(
            f"FAISS index not found at: {FAISS_INDEX_PATH}\n"
            f"Please run the rebuild_faiss_index.py script first!"
        )

    banner = "=" * 70
    print(f"\n{banner}")
    print("πŸš€ LOADING RAG SYSTEM: Mistral 7B + VLM2Vec + Reranking (OFFLINE)")
    print(f"{banner}\n")

    _load_vectorstore()
    _load_llm()
    _build_retrieval_chain()

    print("βœ… RAG system ready (100% OFFLINE)!\n")
315
+
316
+
317
def _load_embeddings():
    """Instantiate the VLM2Vec embedding wrapper, on GPU when available."""
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return VLM2VecEmbeddings(
        model_path=EMBED_MODEL_PATH,
        device=target_device,
    )
324
+
325
+
326
def _load_vectorstore():
    """Reconstruct the LangChain FAISS vectorstore from disk artifacts.

    Reads the raw FAISS index plus a pickled document list from
    ``FAISS_INDEX_PATH`` and populates the module-global ``vectorstore``.
    Several unpickling strategies are tried because the documents may have
    been serialized under a different Python/pydantic version.

    Raises:
        FileNotFoundError: when the index or documents file is missing.
        RuntimeError: when no unpickling strategy succeeds.
    """
    global vectorstore

    import faiss
    import pickle
    from langchain_community.docstore.in_memory import InMemoryDocstore
    from langchain_core.documents import Document

    print(f"πŸ“₯ Loading FAISS index from: {FAISS_INDEX_PATH}")

    text_index_path = os.path.join(FAISS_INDEX_PATH, "text_index.faiss")
    text_docs_path = os.path.join(FAISS_INDEX_PATH, "text_documents.pkl")

    # Include the offending path in the error (the originals were f-strings
    # with no placeholders).
    if not os.path.exists(text_index_path):
        raise FileNotFoundError(f"text_index.faiss not found at {text_index_path}")
    if not os.path.exists(text_docs_path):
        raise FileNotFoundError(f"text_documents.pkl not found at {text_docs_path}")

    embedding_model = _load_embeddings()

    try:
        index = faiss.read_index(text_index_path)
        print(f" πŸ“Š FAISS index loaded: {index.ntotal} vectors")

        print(" πŸ“„ Loading documents...")

        # SECURITY: pickle/dill deserialization executes arbitrary code —
        # only ever load index files produced by this project's own tooling.
        documents = None

        # Strategy 1: pickle5 backport (protocol-5 pickles on older Pythons).
        try:
            import pickle5
            with open(text_docs_path, 'rb') as f:
                documents = pickle5.load(f)
            print(" βœ… Loaded with pickle5")
        except Exception:
            # (was `except (ImportError, Exception)` — Exception already
            # covers ImportError)
            pass

        # Strategy 2: stdlib pickle with latin1 (py2-era byte strings).
        if documents is None:
            try:
                with open(text_docs_path, 'rb') as f:
                    documents = pickle.load(f, encoding='latin1')
                print(" βœ… Loaded with latin1 encoding")
            except Exception:
                pass

        # Strategy 3: dill (handles objects plain pickle cannot).
        if documents is None:
            try:
                import dill
                with open(text_docs_path, 'rb') as f:
                    documents = dill.load(f)
                print(" βœ… Loaded with dill")
            except Exception as e:
                print(f" ⚠️ dill failed: {e}")
                raise RuntimeError("Could not load documents. Check pickle version.")

        if isinstance(documents, list):
            print(f" Loaded {len(documents)} documents")

            # Coerce every entry into a langchain Document so the docstore is
            # uniform even when the pickle held foreign/legacy objects.
            reconstructed_docs = []
            for doc in documents:
                if isinstance(doc, Document):
                    reconstructed_docs.append(doc)
                else:
                    try:
                        new_doc = Document(
                            page_content=doc.page_content if hasattr(doc, 'page_content') else str(doc),
                            metadata=doc.metadata if hasattr(doc, 'metadata') else {}
                        )
                        reconstructed_docs.append(new_doc)
                    except Exception as e:
                        print(f" ⚠️ Could not reconstruct document: {e}")

            documents = reconstructed_docs
            # Docstore ids mirror list positions ("0", "1", ...) to match the
            # FAISS row order.
            docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
            index_to_docstore_id = {i: str(i) for i in range(len(documents))}

        elif isinstance(documents, dict):
            print(f" Loaded {len(documents)} documents (dict)")
            docstore = InMemoryDocstore(documents)
            index_to_docstore_id = {i: key for i, key in enumerate(documents.keys())}

        else:
            raise ValueError(f"Unexpected documents format: {type(documents)}")

        vectorstore = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        print(f" πŸ“Š Total vectors: {vectorstore.index.ntotal}")
        print("βœ… FAISS vectorstore loaded\n")

    except Exception as e:
        print(f"❌ Error loading FAISS index: {e}")
        import traceback
        traceback.print_exc()
        raise
425
+
426
+
427
def _load_llm():
    """Load the local causal LLM in 4-bit NF4 quantization and publish it as
    the module-global ``llm_pipeline`` (a LangChain HuggingFacePipeline).

    Tries Flash Attention 2 first for speed and falls back to the default
    attention implementation when the installed build/hardware does not
    support it. Runs fully offline (``local_files_only=True``).
    """
    print(f"πŸ€– Loading LLM from: {LLM_MODEL_PATH} (OFFLINE - SPEED OPTIMIZED)")

    # 4-bit NF4 with double quantization keeps VRAM usage low while retaining
    # fp16 compute for the matmuls.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH, local_files_only=True)
    if tokenizer.pad_token_id is None:
        # Many chat models ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Prefer Flash Attention 2; fall back to standard attention if unavailable.
    # FIX: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
    # not silently swallowed during model load.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_PATH,
            quantization_config=bnb_config,
            device_map="auto",
            local_files_only=True,
            attn_implementation="flash_attention_2"  # <--- SPEED BOOST
        )
        print(" ⚑ Flash Attention 2 Enabled!")
    except Exception:
        print(" ⚠️ Flash Attention 2 not supported. Using standard attention.")
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_PATH,
            quantization_config=bnb_config,
            device_map="auto",
            local_files_only=True,
        )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.01,  # near-greedy decoding while keeping sampling on
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False  # don't echo the prompt back (stops repetition)
    )

    global llm_pipeline
    llm_pipeline = HuggingFacePipeline(pipeline=pipe)
    print("βœ… LLM Loaded\n")
477
def format_docs_with_sources(docs):
    """Render retrieved documents as one context string, each chunk headed by
    its source file name and page number so the LLM can cite them.
    """
    def _render(doc):
        # Fall back to placeholders when metadata is missing.
        src = doc.metadata.get("source", "Unknown Document")
        # Optional: Clean the path to just show filename
        # src = src.split("\\")[-1]
        pg = doc.metadata.get("page", "?")
        return f"--- REFERENCE: {src} (Page {pg}) ---\n{doc.page_content}\n"

    return "\n\n".join(_render(doc) for doc in docs)
+
494
+
495
def _build_retrieval_chain():
    """Assemble the module-global ``qa_chain``.

    Pipeline: hybrid retrieval (FAISS dense + BM25 keyword, BM25-weighted),
    optional Flashrank reranking down to 5 chunks, history-aware question
    rewriting, then a citation-enforcing answer prompt through the local LLM.
    Optional stages degrade gracefully when their dependencies fail.
    """
    global qa_chain
    print("πŸ”— Building Production RAG Chain (Sources + Hybrid)...")

    # --- A. RETRIEVER SETUP (Speed Optimized) ---

    # 1. Vector Search
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

    # 2. BM25 (keyword search) blended with dense retrieval; weights favour
    # BM25 (0.7) over vectors (0.3).
    # FIX: narrowed from a bare `except:` so interrupts are not swallowed.
    try:
        all_docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(all_docs)
        bm25_retriever.k = 10

        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.3, 0.7]
        )
    except Exception:
        # BM25 build failed (e.g. empty docstore) -> vector-only retrieval.
        ensemble_retriever = faiss_retriever

    # 3. Rerank the merged candidates down to the top 5.
    try:
        compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2", top_n=5)
        final_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=ensemble_retriever
        )
    except Exception:
        # Reranker unavailable -> use the ensemble results directly.
        final_retriever = ensemble_retriever

    # --- B. HISTORY AWARENESS ---

    # Reformulate the question as a standalone query using recent history.
    rephrase_prompt = ChatPromptTemplate.from_template(
        """<s>[INST] Rephrase the follow-up question to be a standalone question.
Chat History: {chat_history}
Follow Up Input: {input}
Standalone question: [/INST]"""
    )

    history_node = create_history_aware_retriever(
        llm_pipeline,
        final_retriever,
        rephrase_prompt
    )

    # --- C. FINAL ANSWER GENERATION (With Sources) ---

    qa_prompt = ChatPromptTemplate.from_template(
        """[INST] You are a helpful assistant for BPCL-Kochi Refinery.
Answer the user's question based strictly on the context provided below.
If the answer is not in the context, say "I don't have that information in the manuals."
ALWAYS cite the document name for your answer.

CONTEXT WITH SOURCES:
{context}

USER QUESTION:
{input}

ANSWER: [/INST]"""
    )

    # The chain (no cache layer): retrieve -> format with citations -> answer.
    qa_chain = (
        {
            "context": history_node | format_docs_with_sources,
            "input": itemgetter("input"),
            "chat_history": itemgetter("chat_history"),
        }
        | qa_prompt
        | llm_pipeline
        | StrOutputParser()
    )

    print("βœ… Production Chain Built (with Citations)\n")
+ # -------------------------------------------------------------------------
574
+ # 6. FastAPI App & Endpoints
575
+ # -------------------------------------------------------------------------
576
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the full RAG stack before serving requests,
    then drop the answer cache and cached GPU memory on shutdown."""
    print("\nπŸš€ Starting application (OFFLINE)...")
    load_system()
    logger.info("RAG system initialized (OFFLINE)")

    yield  # the app serves requests between startup and shutdown

    print("\nπŸ›‘ Shutting down...")
    answer_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Shutdown complete")
+
590
+
591
# FastAPI application; `lifespan` loads the RAG stack at startup.
app = FastAPI(
    title="BeRU Chat Assistant - VLM2Vec",
    description="100% Offline RAG system with VLM2Vec embeddings",
    version="2.0-VLM2Vec",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; fine for a local/offline deployment, revisit before exposing
# this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
605
+
606
+
607
class ChatRequest(BaseModel):
    """Payload for POST /api/chat."""

    message: str = Field(..., min_length=1, max_length=2000)
    mode: str = "Detailed"
    session_id: Optional[str] = "default"
    include_images: bool = False

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        # Trim surrounding whitespace before any further processing.
        return v.strip()

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v):
        # Unknown modes silently fall back to the default answer style.
        return v if v in PROMPT_TEMPLATES else "Detailed"
+
625
+
626
class QueryRequest(BaseModel):
    """Payload for the legacy POST /api/query endpoint."""

    message: str = Field(..., min_length=1, max_length=2000)
    answer_style: str = "Detailed"
    num_sources: int = Field(default=5, ge=1, le=10)

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        # Trim surrounding whitespace before any further processing.
        return v.strip()

    @field_validator("answer_style")
    @classmethod
    def validate_style(cls, v):
        # Unknown styles silently fall back to the default answer style.
        return v if v in PROMPT_TEMPLATES else "Detailed"
+
643
+
644
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the chat UI from frontend.html, or a diagnostic page when the
    file is missing or unreadable."""
    try:
        frontend_path = Path("frontend.html")
        if not frontend_path.exists():
            # Help the operator locate the problem: show where we looked.
            return f"""
            <html>
            <body>
                <h1>Error: frontend.html not found</h1>
                <p>Please place frontend.html in the same directory as this script</p>
                <p>Current directory: {Path.cwd()}</p>
            </body>
            </html>
            """
        return frontend_path.read_text(encoding="utf-8")
    except Exception as e:
        return f"<html><body><h1>Error loading frontend</h1><p>{str(e)}</p></body></html>"
+
664
+
665
# At most 3 generations in flight so the single local LLM is not oversubscribed.
query_semaphore = asyncio.Semaphore(3)
666
+
667
+
668
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer a chat message with the RAG chain.

    Flow: cache lookup -> build LangChain message objects from the last 3
    turns -> invoke ``qa_chain`` in a worker thread under a hard timeout ->
    normalize the chain output (string or dict) -> cache the answer and
    record the turn.

    Returns JSON {response, sources, mode, cached, images}; 504 on timeout,
    500 on unexpected errors.
    """
    # Bounded concurrency: see query_semaphore defined above.
    async with query_semaphore:
        try:
            message = request.message
            mode = request.mode
            session_id = request.session_id

            logger.info(f"Chat Query: {message[:100]} | Mode: {mode}")
            print(f"\n{'=' * 60}")
            print(f"πŸ’¬ Chat: {message}")
            print(f" Mode: {mode}")
            print(f" Session: {session_id}")

            # History Management
            if session_id not in conversations:
                conversations[session_id] = []

            # Check Cache
            cache_key = f"{message}_{mode}_{session_id}"
            if cache_key in answer_cache:
                print("πŸ’Ύ Cache hit!")
                cached_response = answer_cache[cache_key]
                conversations[session_id].append(
                    {
                        "user": message,
                        "bot": cached_response["response"],
                        "mode": mode,
                    }
                )
                # FIX: stored entries carry cached=False from their original
                # generation; report cached=True on an actual cache hit.
                return JSONResponse({**cached_response, "cached": True})

            print(f"⏱️ Generating response (timeout: {GENERATION_TIMEOUT}s)...")

            # Convert dict history to LangChain objects (last 3 turns only,
            # to keep the rephrase prompt short).
            chat_history_objs = []
            for turn in conversations[session_id][-3:]:
                chat_history_objs.append(HumanMessage(content=turn["user"]))
                chat_history_objs.append(AIMessage(content=turn["bot"]))

            # Execute the chain off the event loop with a hard timeout.
            try:
                result = await asyncio.wait_for(
                    asyncio.to_thread(
                        qa_chain.invoke,
                        {
                            "input": message,
                            "chat_history": chat_history_objs
                        },
                    ),
                    timeout=GENERATION_TIMEOUT,
                )
            except asyncio.TimeoutError:
                return JSONResponse(
                    {
                        "error": f"Query timeout after {GENERATION_TIMEOUT}s",
                        "response": "Sorry, the request took too long. Please try again.",
                    },
                    status_code=504,
                )

            # The production chain returns a plain string; the older chain
            # returned a dict with "answer"/"context". Handle both so neither
            # path raises AttributeError.
            context_docs = []  # empty when the string-returning chain is used

            if isinstance(result, str):
                # New "Production Chain" path: citations are embedded in the
                # text itself, so no raw docs are available for `sources`.
                answer = result
            elif isinstance(result, dict):
                # Old "Standard Chain" path
                answer = result.get("answer", "No answer generated")
                context_docs = result.get("context", [])
            else:
                answer = str(result)

            # Clean up the answer text per the requested mode.
            answer = structure_answer(answer, mode)

            # Process sources (only populated when context_docs were returned).
            sources = []
            for i, doc in enumerate(context_docs[:5], 1):
                sources.append(
                    {
                        "index": i,
                        "file_name": doc.metadata.get("source", "Unknown"),
                        "page": doc.metadata.get("page", "N/A"),
                        "snippet": doc.page_content[:200].replace("\n", " "),
                    }
                )

            print(f"βœ… Response generated: {len(answer)} chars")

            response_data = {
                "response": answer,
                "sources": sources,
                "mode": mode,
                "cached": False,
                "images": []  # placeholder for image handling
            }

            answer_cache[cache_key] = response_data

            conversations[session_id].append(
                {
                    "user": message,
                    "bot": answer,
                    "mode": mode,
                }
            )

            logger.info("Chat response completed")
            return JSONResponse(response_data)

        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            print(f"❌ ERROR: {e}")
            # Ensure traceback is printed to console for debugging
            import traceback
            traceback.print_exc()
            return JSONResponse(
                {
                    "error": str(e),
                    "response": "Sorry, an internal error occurred. Please check server logs.",
                },
                status_code=500,
            )
@app.post("/api/query")
async def query_endpoint(request: QueryRequest):
    """Legacy endpoint: delegate to /api/chat, then rename the JSON key
    'response' to 'answer' for old clients."""
    import json

    proxied = await chat_endpoint(
        ChatRequest(
            message=request.message,
            mode=request.answer_style,
            session_id="default",
        )
    )
    payload = json.loads(proxied.body.decode("utf-8"))
    if "response" in payload:
        payload["answer"] = payload.pop("response")
    return JSONResponse(payload)
+
816
+
817
@app.get("/api/health")
async def health_check():
    """Lightweight liveness probe exposing basic runtime information."""
    return dict(
        status="ok",
        mode="OFFLINE",
        llm_model=LLM_MODEL,
        embedding_model=EMBED_MODEL,
        cuda_available=torch.cuda.is_available(),
        cache_size=len(answer_cache),
        active_sessions=len(conversations),
    )
+
829
+
830
@app.get("/api/stats")
async def get_stats():
    """Report corpus, cache, and session statistics; the document count
    degrades to 'unknown' if the docstore is inaccessible."""
    try:
        doc_count = len(vectorstore.docstore._dict) if vectorstore else 0
    except Exception:
        doc_count = "unknown"

    return dict(
        mode="OFFLINE",
        documents=doc_count,
        cache_size=len(answer_cache),
        active_sessions=len(conversations),
        llm_model=LLM_MODEL,
        embedding_model=EMBED_MODEL,
        cuda_available=torch.cuda.is_available(),
        index_path=FAISS_INDEX_PATH,
    )
+
848
+
849
@app.post("/api/new-conversation")
async def new_conversation(request: dict):
    """Reset the chat history for the given session.

    FIX: previously an unknown session_id was silently ignored, so a fresh
    client could not register a new session; now the history is always
    (re)initialized to an empty list.
    """
    session_id = request.get("session_id", "default")
    conversations[session_id] = []
    return {"message": "New conversation started", "session_id": session_id}
+
856
+
857
@app.get("/api/conversation/{session_id}")
async def get_conversation(session_id: str):
    """Return the stored chat turns for a session (empty list if unknown)."""
    return {"history": conversations.get(session_id, [])}
+
863
+
864
@app.get("/api/clear_cache")
async def clear_cache():
    """Empty the answer cache and release cached GPU memory, reporting how
    many entries were removed."""
    removed = len(answer_cache)
    answer_cache.clear()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {"message": f"Cache cleared. Removed {removed} entries"}
+
874
+
875
if __name__ == "__main__":
    # Script entry point: parse the port and launch the uvicorn server.
    # FIX: removed unused `import sys`.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8001, help="Port to run the server on")
    args = parser.parse_args()

    port = args.port

    print("\n" + "=" * 70)
    print("πŸš€ BeRU Chat Assistant - VLM2Vec Mode (100% OFFLINE)")
    print("=" * 70)
    print(f"\nπŸ“ Frontend: http://localhost:{port}")
    print(f"πŸ“ API Docs: http://localhost:{port}/docs")
    print(f"πŸ“ Health: http://localhost:{port}/api/health")
    print(f"πŸ“ Stats: http://localhost:{port}/api/stats")
    print(f"\nπŸ”Œ Embedding Model (LOCAL): {EMBED_MODEL_PATH}")
    print(f"πŸ”Œ LLM Model (LOCAL): {LLM_MODEL_PATH}")
    print(f"πŸ”Œ FAISS Index: {FAISS_INDEX_PATH}")
    print("πŸ”Œ Mode: 100% OFFLINE (local files only)")
    print("=" * 70 + "\n")

    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
frontend.html ADDED
@@ -0,0 +1,1075 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>BeRU Chat - Multimodal</title>
7
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
8
+
9
+ <style>
10
+ /* *** EXISTING STYLES (keeping all your original styles) *** */
11
+ body {
12
+ font-family: 'Roboto', sans-serif;
13
+ margin: 0;
14
+ padding: 0;
15
+ transition: background-color 0.5s ease, color 0.5s ease;
16
+ }
17
+
18
+ .light-mode {
19
+ background-color: #caf2fa;
20
+ color: #333;
21
+ }
22
+
23
+ .dark-mode {
24
+ background-color: #1e1e1e;
25
+ color: #f5f5f5;
26
+ }
27
+
28
+ .dark-mode .chat-container {
29
+ background-color: #1e1e1e;
30
+ color: #f5f5f5;
31
+ }
32
+
33
+ .sidebar {
34
+ width: 300px;
35
+ height: 100vh;
36
+ background-color: #0d131a;
37
+ position: fixed;
38
+ top: 0;
39
+ left: 0;
40
+ z-index: 1;
41
+ overflow-x: hidden;
42
+ transition: width 0.3s, background-color 0.5s ease;
43
+ }
44
+
45
+ .light-mode .sidebar {
46
+ background-color: #01414e;
47
+ }
48
+
49
+ .sidebar.collapsed {
50
+ width: 50px;
51
+ }
52
+
53
+ .tooltip {
54
+ position: absolute;
55
+ top: 0;
56
+ right: -20px;
57
+ }
58
+
59
+ .tooltip .tooltiptext {
60
+ visibility: hidden;
61
+ width: 120px;
62
+ background-color: rgb(0, 0, 0);
63
+ color: #fff;
64
+ text-align: center;
65
+ border-radius: 6px;
66
+ padding: 5px 0;
67
+ position: absolute;
68
+ z-index: 1;
69
+ bottom: 125%;
70
+ left: 50%;
71
+ margin-left: -60px;
72
+ opacity: 0;
73
+ transition: opacity 0.3s;
74
+ }
75
+
76
+ #sidebar-toggle {
77
+ background-color: transparent;
78
+ margin: -22%;
79
+ align-items: center;
80
+ margin-top: 600%;
81
+ }
82
+
83
+ .tooltip .tooltiptext::after {
84
+ content: "";
85
+ position: absolute;
86
+ top: 100%;
87
+ left: 50%;
88
+ border-width: 5px;
89
+ border-style: solid;
90
+ border-color: black transparent transparent transparent;
91
+ }
92
+
93
+ .tooltip:hover .tooltiptext {
94
+ visibility: visible;
95
+ opacity: 1;
96
+ }
97
+
98
+ .sidebar-content {
99
+ padding-top: 20px;
100
+ transition: opacity 0.3s;
101
+ }
102
+
103
+ .sidebar.collapsed .sidebar-content {
104
+ opacity: 0;
105
+ pointer-events: none;
106
+ }
107
+
108
+ .conversation-list {
109
+ padding: 0 20px;
110
+ }
111
+
112
+ .conversation {
113
+ margin-bottom: 10px;
114
+ }
115
+
116
+ .conversation-text {
117
+ font-weight: bold;
118
+ color: #fff;
119
+ transition: color 0.5s ease;
120
+ }
121
+
122
+ .light-mode .conversation-text {
123
+ color: #ccc;
124
+ }
125
+
126
+ .conversation-content {
127
+ color: #ddd;
128
+ transition: color 0.5s ease;
129
+ }
130
+
131
+ .light-mode .conversation-content {
132
+ color: #888;
133
+ }
134
+
135
+ #new-conversation-btn {
136
+ background-color: #3a3b3b;
137
+ color: #fff;
138
+ border: none;
139
+ padding: 10px 20px;
140
+ border-radius: 5px;
141
+ cursor: pointer;
142
+ transition: background-color 0.3s, color 0.5s ease;
143
+ }
144
+
145
+ #new-conversation-btn:hover {
146
+ background-color: #242020;
147
+ }
148
+
149
+ .light-mode #new-conversation-btn {
150
+ background-color: #c9c9c9;
151
+ color: #171717;
152
+ }
153
+
154
+ .light-mode #new-conversation-btn:hover {
155
+ background-color: #e0e0e0;
156
+ color: #171717;
157
+ }
158
+
159
+ .chat-container {
160
+ width: calc(100% - 300px);
161
+ margin-left: 300px;
162
+ height: 100vh;
163
+ overflow: hidden;
164
+ transition: all 0.3s ease-in-out, background-color 0.5s ease;
165
+ }
166
+
167
+ .sidebar.collapsed ~ .chat-container {
168
+ width: calc(100% - 50px);
169
+ margin-left: 50px;
170
+ }
171
+
172
+ .chat-content {
173
+ display: flex;
174
+ flex-direction: column;
175
+ height: 100%;
176
+ padding-bottom: 80px;
177
+ }
178
+
179
+ .logo-container {
180
+ display: flex;
181
+ align-items: center;
182
+ }
183
+
184
+ .logo {
185
+ width: 30px;
186
+ height: 30px;
187
+ margin-right: 10px;
188
+ }
189
+
190
+ .chat-header {
191
+ margin-left: 2%;
192
+ display: flex;
193
+ align-items: center;
194
+ justify-content: space-between;
195
+ font-size: 10px;
196
+ height: 60px;
197
+ background-color: #171717;
198
+ transition: background-color 0.5s ease, color 0.5s ease;
199
+ }
200
+
201
+ .light-mode h1 {
202
+ color: black;
203
+ }
204
+
205
+ .dark-mode h1 {
206
+ color: #f5f5f5;
207
+ }
208
+
209
+ .light-mode .chat-header {
210
+ background-color: #caf2fa;
211
+ color: #333;
212
+ }
213
+
214
+ .dark-mode .chat-header {
215
+ background-color: #1e1e1e;
216
+ color: #f5f5f5;
217
+ }
218
+
219
+ h1 {
220
+ color: #cfcfcf;
221
+ font-family: 'Trebuchet MS', sans-serif;
222
+ transition: color 0.5s ease;
223
+ }
224
+
225
+ /* Toggle Switch Styles */
226
+ .toggle-switch {
227
+ position: relative;
228
+ width: 50px;
229
+ height: 25px;
230
+ margin-right: 20px;
231
+ --light: #d8dbe0;
232
+ --dark: #28292c;
233
+ }
234
+
235
+ .switch-label {
236
+ position: absolute;
237
+ width: 100%;
238
+ height: 100%;
239
+ background-color: var(--dark);
240
+ border-radius: 12px;
241
+ cursor: pointer;
242
+ border: 1.5px solid var(--dark);
243
+ transition: background-color 0.3s;
244
+ }
245
+
246
+ .checkbox {
247
+ position: absolute;
248
+ display: none;
249
+ }
250
+
251
+ .slider {
252
+ position: absolute;
253
+ width: 100%;
254
+ height: 100%;
255
+ border-radius: 12px;
256
+ transition: 0.3s;
257
+ }
258
+
259
+ .checkbox:checked ~ .slider {
260
+ background-color: var(--light);
261
+ }
262
+
263
+ .slider::before {
264
+ content: "";
265
+ position: absolute;
266
+ top: 5.5px;
267
+ left: 5.5px;
268
+ width: 14px;
269
+ height: 14px;
270
+ border-radius: 50%;
271
+ box-shadow: inset 7px -2px 0px 0px var(--light);
272
+ background-color: var(--dark);
273
+ transition: 0.3s;
274
+ }
275
+
276
+ .checkbox:checked ~ .slider::before {
277
+ transform: translateX(26px);
278
+ background-color: var(--dark);
279
+ box-shadow: none;
280
+ }
281
+
282
+ .chat-box {
283
+ display: flex;
284
+ flex-direction: column;
285
+ flex: 1;
286
+ overflow-y: auto;
287
+ padding: 15px;
288
+ overflow-x: hidden;
289
+ }
290
+
291
+ .chat-box::-webkit-scrollbar {
292
+ width: 3px;
293
+ }
294
+
295
+ .chat-box::-webkit-scrollbar-track {
296
+ background: transparent;
297
+ }
298
+
299
+ .chat-box::-webkit-scrollbar-track-piece {
300
+ background: #b0b0b000;
301
+ border-radius: 999px;
302
+ }
303
+
304
+ .chat-box::-webkit-scrollbar-thumb {
305
+ background-color: #ffd700;
306
+ border-radius: 999px;
307
+ border: 2px solid transparent;
308
+ background-clip: padding-box;
309
+ }
310
+
311
+ .chat-box {
312
+ scrollbar-width: thin;
313
+ scrollbar-color: #ffd700 #b0b0b0;
314
+ }
315
+
316
+ .chat-box p {
317
+ margin: 10px 0;
318
+ font-size: 16px;
319
+ }
320
+
321
+ .messageBox {
322
+ position: fixed;
323
+ bottom: 20px;
324
+ left: 50%;
325
+ transform: translateX(-50%);
326
+ display: flex;
327
+ align-items: center;
328
+ background-color: #2d2d2d;
329
+ padding: 0 12px;
330
+ border-radius: 10px;
331
+ border: 1px solid rgb(63, 63, 63);
332
+ width: 60%;
333
+ max-width: 800px;
334
+ height: 50px;
335
+ transition: all 0.3s ease-in-out, background-color 0.5s ease, border-color 0.5s ease;
336
+ }
337
+
338
+ .light-mode .messageBox {
339
+ background-color: white;
340
+ border: 1px solid #1d495f;
341
+ }
342
+
343
+ .messageBox:focus-within {
344
+ border: 1px solid rgb(110, 110, 110);
345
+ }
346
+
347
+ #messageInput {
348
+ flex: 1;
349
+ height: 100%;
350
+ background-color: transparent;
351
+ outline: none;
352
+ border: none;
353
+ padding: 0 12px;
354
+ color: white;
355
+ width: auto;
356
+ font-family: 'Roboto', sans-serif;
357
+ font-size: 14px;
358
+ transition: color 0.5s ease;
359
+ }
360
+
361
+ .light-mode #messageInput {
362
+ color: #171717;
363
+ }
364
+
365
+ #sendButton {
366
+ width: 50px;
367
+ height: 100%;
368
+ background-color: transparent;
369
+ outline: none;
370
+ border: none;
371
+ display: flex;
372
+ align-items: center;
373
+ justify-content: center;
374
+ cursor: pointer;
375
+ padding: 0;
376
+ }
377
+
378
+ #sendButton svg {
379
+ height: 60%;
380
+ width: auto;
381
+ transition: all 0.3s;
382
+ }
383
+
384
+ #sendButton svg path {
385
+ transition: all 0.3s;
386
+ }
387
+
388
+ #sendButton:hover svg path,
389
+ #sendButton:active svg path {
390
+ fill: #3c3c3c;
391
+ stroke: white;
392
+ }
393
+
394
+ .light-mode #sendButton:hover svg path,
395
+ .light-mode #sendButton:active svg path {
396
+ fill: #fbf7e7;
397
+ stroke: #ffcd07 !important;
398
+ }
399
+
400
+ .message-row {
401
+ width: 100%;
402
+ margin: 8px 0;
403
+ display: flex;
404
+ }
405
+
406
+ .message-user {
407
+ justify-content: flex-end;
408
+ }
409
+
410
+ .message-bot {
411
+ justify-content: flex-start;
412
+ }
413
+
414
+ .message-bubble {
415
+ max-width: 70%;
416
+ padding: 10px 14px;
417
+ border-radius: 16px;
418
+ font-size: 14px;
419
+ transition: background-color 0.5s ease, color 0.5s ease;
420
+ }
421
+
422
+ .message-user .message-bubble {
423
+ background-color: #1b798e;
424
+ color: white;
425
+ border-bottom-right-radius: 4px;
426
+ border: 1px solid #FFD700;
427
+ }
428
+
429
+ .message-bot .message-bubble {
430
+ background-color: #2d2d2d;
431
+ color: white;
432
+ border-bottom-left-radius: 4px;
433
+ border: 1px solid #FFD700;
434
+ }
435
+
436
+ .light-mode .message-bot .message-bubble {
437
+ background-color: #fefdf6;
438
+ color: #111;
439
+ }
440
+
441
+ /* βœ… NEW: Image Gallery Styles */
442
+ .image-gallery {
443
+ display: flex;
444
+ gap: 10px;
445
+ flex-wrap: wrap;
446
+ margin-top: 12px;
447
+ padding-top: 12px;
448
+ border-top: 1px solid rgba(255, 215, 0, 0.3);
449
+ }
450
+
451
+ .image-container {
452
+ position: relative;
453
+ border-radius: 8px;
454
+ overflow: hidden;
455
+ cursor: pointer;
456
+ transition: transform 0.2s ease;
457
+ border: 2px solid #FFD700;
458
+ }
459
+
460
+ .image-container:hover {
461
+ transform: scale(1.05);
462
+ }
463
+
464
+ .image-container img {
465
+ width: 150px;
466
+ height: 150px;
467
+ object-fit: cover;
468
+ display: block;
469
+ }
470
+
471
+ .image-caption {
472
+ position: absolute;
473
+ bottom: 0;
474
+ left: 0;
475
+ right: 0;
476
+ background: rgba(0, 0, 0, 0.7);
477
+ color: #FFD700;
478
+ padding: 4px 8px;
479
+ font-size: 10px;
480
+ text-align: center;
481
+ }
482
+
483
+ /* βœ… NEW: Image Modal/Lightbox */
484
+ .image-modal {
485
+ display: none;
486
+ position: fixed;
487
+ z-index: 9999;
488
+ left: 0;
489
+ top: 0;
490
+ width: 100%;
491
+ height: 100%;
492
+ background-color: rgba(0, 0, 0, 0.9);
493
+ align-items: center;
494
+ justify-content: center;
495
+ }
496
+
497
+ .image-modal.active {
498
+ display: flex;
499
+ }
500
+
501
+ .modal-content {
502
+ max-width: 90%;
503
+ max-height: 90%;
504
+ border-radius: 8px;
505
+ box-shadow: 0 4px 20px rgba(255, 215, 0, 0.5);
506
+ }
507
+
508
+ .modal-close {
509
+ position: absolute;
510
+ top: 20px;
511
+ right: 35px;
512
+ color: #FFD700;
513
+ font-size: 40px;
514
+ font-weight: bold;
515
+ cursor: pointer;
516
+ transition: color 0.3s;
517
+ }
518
+
519
+ .modal-close:hover {
520
+ color: #fff;
521
+ }
522
+
523
+ .fileUploadWrapper {
524
+ position: relative;
525
+ display: flex;
526
+ align-items: center;
527
+ margin-right: 8px;
528
+ }
529
+
530
+ .fileUploadWrapper label {
531
+ display: flex;
532
+ align-items: center;
533
+ cursor: pointer;
534
+ padding: 0;
535
+ margin: 0;
536
+ }
537
+
538
+ .fileUploadWrapper svg {
539
+ width: 20px;
540
+ height: 20px;
541
+ fill: #6c6c6c;
542
+ transition: all 0.3s ease;
543
+ }
544
+
545
+ .fileUploadWrapper label:hover svg {
546
+ fill: #10a37f;
547
+ }
548
+
549
+ .fileUploadWrapper .tooltip {
550
+ display: none;
551
+ position: absolute;
552
+ bottom: 125%;
553
+ left: 50%;
554
+ transform: translateX(-50%);
555
+ background-color: #000;
556
+ color: #fff;
557
+ padding: 5px 10px;
558
+ border-radius: 6px;
559
+ font-size: 12px;
560
+ white-space: nowrap;
561
+ z-index: 1;
562
+ }
563
+
564
+ .fileUploadWrapper label:hover .tooltip {
565
+ display: block;
566
+ }
567
+
568
+ .fileUploadWrapper input {
569
+ display: none;
570
+ }
571
+
572
+ .fileUploadWrapper,
573
+ #sendButton {
574
+ width: 40px;
575
+ height: 40px;
576
+ display: flex;
577
+ align-items: center;
578
+ justify-content: center;
579
+ margin: 0 8px;
580
+ }
581
+
582
+ .fileUploadWrapper svg,
583
+ #sendButton svg {
584
+ width: 20px;
585
+ height: 20px;
586
+ fill: #6c6c6c;
587
+ transition: all 0.3s ease;
588
+ }
589
+
590
+ .fileUploadWrapper label:hover svg path,
591
+ .fileUploadWrapper label:hover svg circle,
592
+ #sendButton:hover svg path {
593
+ stroke: white;
594
+ transition: stroke 0.3s ease;
595
+ }
596
+
597
+ .light-mode .fileUploadWrapper label:hover svg path,
598
+ .light-mode .fileUploadWrapper label:hover svg circle,
599
+ .light-mode #sendButton:hover svg path {
600
+ stroke: #ffcd07;
601
+ transition: stroke 0.3s ease;
602
+ }
603
+
604
+ #sendButton svg,
605
+ .fileUploadWrapper svg {
606
+ transition: all 0.3s ease;
607
+ }
608
+
609
+ /* Mode Dropdown Styles */
610
+ .mode-dropdown-wrapper {
611
+ position: relative;
612
+ display: flex;
613
+ align-items: center;
614
+ margin: 0 8px;
615
+ }
616
+
617
+ .mode-dropdown-button {
618
+ display: inline-flex;
619
+ justify-content: center;
620
+ align-items: center;
621
+ padding: 6px 12px;
622
+ background-color: transparent;
623
+ border: none;
624
+ color: #e5e5e5;
625
+ font-size: 14px;
626
+ font-family: 'Roboto', sans-serif;
627
+ font-weight: 500;
628
+ cursor: pointer;
629
+ border-radius: 6px;
630
+ white-space: nowrap;
631
+ transition: background-color 0.3s, color 0.5s ease;
632
+ }
633
+
634
+ .light-mode .mode-dropdown-button {
635
+ background-color: #f9f9f9;
636
+ color: #171717;
637
+ }
638
+
639
+ .mode-dropdown-button:hover {
640
+ background-color: rgba(110, 110, 110, 0.12);
641
+ }
642
+
643
+ .light-mode .mode-dropdown-button:hover {
644
+ background-color: #f0f0f0;
645
+ }
646
+
647
+ .dropdown-arrow {
648
+ width: 16px;
649
+ height: 16px;
650
+ margin-left: 6px;
651
+ transition: transform 0.2s ease, color 0.5s ease;
652
+ }
653
+
654
+ .mode-dropdown-button[aria-expanded="true"] .dropdown-arrow {
655
+ transform: rotate(180deg);
656
+ }
657
+
658
+ .mode-dropdown-menu {
659
+ position: absolute;
660
+ bottom: 100%;
661
+ right: 0;
662
+ margin-bottom: 8px;
663
+ min-width: 180px;
664
+ background-color: #2d2d2d;
665
+ border: 1px solid rgb(63, 63, 63);
666
+ border-radius: 8px;
667
+ box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.4);
668
+ z-index: 1000;
669
+ opacity: 0;
670
+ transform: translateY(-8px);
671
+ transition: opacity 0.2s ease, transform 0.2s ease, background-color 0.5s ease, border-color 0.5s ease;
672
+ pointer-events: none;
673
+ }
674
+
675
+ .mode-dropdown-menu:not(.hidden) {
676
+ opacity: 1;
677
+ transform: translateY(0);
678
+ pointer-events: auto;
679
+ }
680
+
681
+ .light-mode .mode-dropdown-menu {
682
+ background-color: #ffffff;
683
+ border-color: #1d495f;
684
+ box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.15);
685
+ }
686
+
687
+ .dropdown-items {
688
+ padding: 4px 0;
689
+ }
690
+
691
+ .dropdown-item {
692
+ display: flex;
693
+ align-items: center;
694
+ padding: 10px 16px;
695
+ color: #ddd;
696
+ text-decoration: none;
697
+ font-size: 14px;
698
+ font-family: 'Roboto', sans-serif;
699
+ cursor: pointer;
700
+ transition: background-color 0.15s ease, color 0.5s ease;
701
+ }
702
+
703
+ .dropdown-item:first-child {
704
+ border-top-left-radius: 8px;
705
+ border-top-right-radius: 8px;
706
+ }
707
+
708
+ .dropdown-item:last-child {
709
+ border-bottom-left-radius: 8px;
710
+ border-bottom-right-radius: 8px;
711
+ }
712
+
713
+ .dropdown-item:hover {
714
+ background-color: rgba(110, 110, 110, 0.25);
715
+ color: #fff;
716
+ }
717
+
718
+ .light-mode .dropdown-item {
719
+ color: #1f2937;
720
+ }
721
+
722
+ .light-mode .dropdown-item:hover {
723
+ background-color: #f3f4f6;
724
+ color: #1e40af;
725
+ }
726
+
727
+ .dropdown-item .font-semibold {
728
+ font-weight: 600;
729
+ }
730
+
731
+ .hidden {
732
+ display: none;
733
+ }
734
+ </style>
735
+ </head>
736
+ <body>
737
+ <div class="sidebar collapsed">
738
+ <div class="tooltip">
739
+ <span class="tooltiptext">Open Sidebar</span>
740
+ <button id="sidebar-toggle">
741
+ <i class="fas fa-chevron-right"></i>
742
+ </button>
743
+ </div>
744
+ <div class="sidebar-content">
745
+ <div class="conversation-list">
746
+ <div class="conversation">
747
+ <p class="conversation-text">Last Conversation:</p>
748
+ <p class="conversation-content">No conversation yet</p>
749
+ </div>
750
+ </div>
751
+ <button id="new-conversation-btn">Start New Conversation</button>
752
+ </div>
753
+ </div>
754
+
755
+ <div class="chat-container light-mode">
756
+ <div class="chat-content">
757
+ <div class="chat-header">
758
+ <div class="logo-container">
759
+ <h1>BeRU&nbsp;</h1>
760
+ </div>
761
+ <div class="toggle-switch">
762
+ <label class="switch-label">
763
+ <input type="checkbox" id="toggle-checkbox" class="checkbox">
764
+ <span class="slider"></span>
765
+ </label>
766
+ </div>
767
+ </div>
768
+
769
+ <div id="chat-box" class="chat-box"></div>
770
+
771
+ <div class="messageBox">
772
+ <div class="fileUploadWrapper">
773
+ <label for="file">
774
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 337 337">
775
+ <circle stroke-width="20" stroke="#6c6c6c" fill="none" r="158.5" cy="168.5" cx="168.5"></circle>
776
+ <path stroke-linecap="round" stroke-width="25" stroke="#6c6c6c" d="M167.759 79V259"></path>
777
+ <path stroke-linecap="round" stroke-width="25" stroke="#6c6c6c" d="M79 167.138H259"></path>
778
+ </svg>
779
+ <span class="tooltip">Add an image</span>
780
+ </label>
781
+ <input type="file" id="file" name="file" />
782
+ </div>
783
+
784
+ <input required="" placeholder="Message..." type="text" id="messageInput" />
785
+
786
+ <div class="mode-dropdown-wrapper">
787
+ <button id="modeDropdownButton" type="button" class="mode-dropdown-button" aria-expanded="false" aria-haspopup="true">
788
+ Detailed
789
+ <svg class="dropdown-arrow" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
790
+ <path fill-rule="evenodd" d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z" clip-rule="evenodd" />
791
+ </svg>
792
+ </button>
793
+ <div id="modeDropdownMenu" class="mode-dropdown-menu hidden" role="menu" aria-orientation="vertical" aria-labelledby="modeDropdownButton" tabindex="-1">
794
+ <div class="dropdown-items">
795
+ <a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
796
+ <span class="font-semibold">Short and Concise</span>
797
+ </a>
798
+ <a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
799
+ <span class="font-semibold">Detailed</span>
800
+ </a>
801
+ <a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
802
+ <span class="font-semibold">Step-by-Step</span>
803
+ </a>
804
+ </div>
805
+ </div>
806
+ </div>
807
+
808
+ <button id="sendButton">
809
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 664 663">
810
+ <path fill="none" d="M646.293 331.888L17.7538 17.6187L155.245 331.888M646.293 331.888L17.753 646.157L155.245 331.888M646.293 331.888L318.735 330.228L155.245 331.888"></path>
811
+ <path stroke-linejoin="round" stroke-linecap="round" stroke-width="33.67" stroke="#6c6c6c" d="M646.293 331.888L17.7538 17.6187L155.245 331.888M646.293 331.888L17.753 646.157L155.245 331.888M646.293 331.888L318.735 330.228L155.245 331.888"></path>
812
+ </svg>
813
+ </button>
814
+ </div>
815
+ </div>
816
+ </div>
817
+
818
+ <!-- βœ… NEW: Image Modal -->
819
+ <div id="imageModal" class="image-modal">
820
+ <span class="modal-close" id="modalClose">&times;</span>
821
+ <img class="modal-content" id="modalImage">
822
+ </div>
823
+
824
+ <script>
825
+ const chatBox = document.getElementById('chat-box');
826
+ const userInput = document.getElementById('messageInput');
827
+ const sendButton = document.getElementById('sendButton');
828
+ const sidebarToggle = document.getElementById('sidebar-toggle');
829
+ const modeToggle = document.getElementById('toggle-checkbox');
830
+ const sidebar = document.querySelector('.sidebar');
831
+ const chatContainer = document.querySelector('.chat-container');
832
+ const messageBox = document.querySelector('.messageBox');
833
+ const imageModal = document.getElementById('imageModal');
834
+ const modalImage = document.getElementById('modalImage');
835
+ const modalClose = document.getElementById('modalClose');
836
+
837
+ let currentMode = 'Detailed';
838
+ let sessionId = 'session-' + Date.now();
839
+
840
// βœ… NEW: Image Modal Functions
// Show the full-size image viewer ("lightbox") for the given image source.
function openImageModal(imageSrc) {
    imageModal.classList.add('active');
    modalImage.src = imageSrc;
}

// Hide the viewer and clear the image so the browser can release it.
function closeImageModal() {
    imageModal.classList.remove('active');
    modalImage.src = '';
}
850
+
851
+ modalClose.addEventListener('click', closeImageModal);
852
+ imageModal.addEventListener('click', (e) => {
853
+ if (e.target === imageModal) {
854
+ closeImageModal();
855
+ }
856
+ });
857
+
858
+ // Utility Functions
859
// Current wall-clock time formatted according to the user's locale.
function getCurrentTime() {
    return new Date().toLocaleTimeString();
}
863
+
864
+ // βœ… MODIFIED: Chat Functions with Image Support
865
// Render one chat message into #chat-box. `sender` selects the row style
// ('user' vs. anything else = bot); `images` is an optional list of
// {data, source, page} objects retrieved by the RAG backend.
function appendMessage(sender, message, images = []) {
    const row = document.createElement('div');
    row.classList.add('message-row');
    row.classList.add(sender === 'user' ? 'message-user' : 'message-bot');

    const bubble = document.createElement('div');
    bubble.classList.add('message-bubble');

    // Text first; newlines become <br> so multi-line answers render.
    bubble.innerHTML = message.replace(/\n/g, '<br>');

    // Optional image gallery appended below the text.
    if (images && images.length > 0) {
        const gallery = document.createElement('div');
        gallery.classList.add('image-gallery');

        for (const img of images) {
            const container = document.createElement('div');
            container.classList.add('image-container');

            const picture = document.createElement('img');
            picture.src = img.data;
            picture.alt = `Image from ${img.source}`;
            picture.loading = 'lazy';
            // Click opens the full-size modal viewer.
            picture.addEventListener('click', () => openImageModal(img.data));

            const caption = document.createElement('div');
            caption.classList.add('image-caption');
            caption.textContent = `πŸ“„ ${img.source} | Page ${img.page}`;

            container.appendChild(picture);
            container.appendChild(caption);
            gallery.appendChild(container);
        }

        bubble.appendChild(gallery);
    }

    row.appendChild(bubble);
    chatBox.appendChild(row);
    // Keep the newest message in view.
    chatBox.scrollTop = chatBox.scrollHeight;
}
911
+
912
// Submit the current input to /api/chat and render the reply (with any
// retrieved images). Shows a temporary "Thinking..." bubble while waiting.
async function sendMessage() {
    const text = userInput.value.trim();
    if (text === '') return;

    appendMessage('user', text);
    userInput.value = '';

    // Placeholder bubble; removed once the server responds (or errors).
    appendMessage('ChatGPT', '⏳ Thinking...');

    try {
        const payload = {
            message: text,
            mode: currentMode,
            session_id: sessionId,
            include_images: true // ask the backend to return images too
        };

        const response = await fetch('/api/chat', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(payload)
        });

        const data = await response.json();

        // Drop the "Thinking..." placeholder.
        chatBox.removeChild(chatBox.lastChild);

        if (data.error) {
            appendMessage('ChatGPT', '❌ Error: ' + data.error);
        } else {
            const images = data.images || [];
            appendMessage('ChatGPT', data.response, images);

            if (images.length > 0) {
                console.log(`πŸ“· Received ${images.length} images`);
            }
        }
    } catch (error) {
        // Drop the "Thinking..." placeholder before showing the failure.
        chatBox.removeChild(chatBox.lastChild);
        appendMessage('ChatGPT', '❌ Connection error. Please check if the server is running.');
        console.error('Error:', error);
    }
}
961
+
962
+ // Event Listeners
963
// Wire up all UI behavior once the DOM is ready: sidebar, theme toggle,
// new-conversation button, send handlers, and the response-mode dropdown.
document.addEventListener('DOMContentLoaded', function() {
    const newConversationBtn = document.getElementById('new-conversation-btn');
    const conversationContent = document.querySelector('.conversation-content');

    // Sidebar Toggle
    sidebarToggle.addEventListener('click', function() {
        sidebar.classList.toggle('collapsed');
        adjustMessageBoxPosition();
    });

    // Keep the input bar horizontally centered in the chat area, whose
    // width depends on the sidebar state (collapsed = 50px, open = 300px).
    function adjustMessageBoxPosition() {
        const sidebarWidth = sidebar.classList.contains('collapsed') ? 50 : 300;
        const chatAreaWidth = window.innerWidth - sidebarWidth;
        messageBox.style.left = sidebarWidth + (chatAreaWidth / 2) + 'px';
        messageBox.style.transform = 'translateX(-50%)';
    }

    adjustMessageBoxPosition();
    window.addEventListener('resize', adjustMessageBoxPosition);

    // New Conversation: clear the UI, rotate the session id, notify the server.
    newConversationBtn.addEventListener('click', async function() {
        conversationContent.textContent = 'New Conversation Started!';
        chatBox.innerHTML = '';
        sessionId = 'session-' + Date.now();

        try {
            await fetch('/api/new-conversation', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({session_id: sessionId})
            });
        } catch (error) {
            // Best-effort: the UI is already reset even if the server call fails.
            console.error('Error starting new conversation:', error);
        }

        adjustMessageBoxPosition();
    });

    // Theme Toggle: flip dark/light on both the body and the chat container.
    modeToggle.addEventListener('change', function() {
        document.body.classList.toggle('dark-mode');
        document.body.classList.toggle('light-mode');
        chatContainer.classList.toggle('light-mode');
        chatContainer.classList.toggle('dark-mode');
        adjustMessageBoxPosition();
    });

    // Set initial mode
    document.body.classList.add('light-mode');

    // Send button and Enter key both submit the message.
    sendButton.addEventListener('click', sendMessage);
    userInput.addEventListener('keydown', (event) => {
        if (event.key === 'Enter') sendMessage();
    });

    // Mode Dropdown (response style: Short / Detailed / Step-by-Step)
    const modeDropdownButton = document.getElementById('modeDropdownButton');
    const modeDropdownMenu = document.getElementById('modeDropdownMenu');

    if (modeDropdownButton && modeDropdownMenu) {
        // Open/close the menu and keep aria-expanded in sync for a11y.
        function toggleMenu() {
            const isHidden = modeDropdownMenu.classList.contains('hidden');
            if (isHidden) {
                modeDropdownMenu.classList.remove('hidden');
                modeDropdownButton.setAttribute('aria-expanded', 'true');
            } else {
                modeDropdownMenu.classList.add('hidden');
                modeDropdownButton.setAttribute('aria-expanded', 'false');
            }
        }

        modeDropdownButton.addEventListener('click', (event) => {
            // Stop propagation so the document-level close handler below
            // does not immediately re-close the menu.
            event.stopPropagation();
            toggleMenu();
        });

        // Close the menu on any click outside the button or the menu itself.
        document.addEventListener('click', (event) => {
            if (!modeDropdownButton.contains(event.target) && !modeDropdownMenu.contains(event.target)) {
                if (!modeDropdownMenu.classList.contains('hidden')) {
                    toggleMenu();
                }
            }
        });

        modeDropdownMenu.querySelectorAll('.dropdown-item').forEach(item => {
            item.addEventListener('click', (event) => {
                event.preventDefault();
                const selectedMode = event.currentTarget.querySelector('.font-semibold').textContent.trim();
                console.log('Mode selected:', selectedMode);

                // `currentMode` is sent with every /api/chat request.
                currentMode = selectedMode;

                // The button label is a bare text node sitting next to the
                // arrow SVG; update it in place, or create it if missing.
                const buttonTextNode = Array.from(modeDropdownButton.childNodes).find(node =>
                    node.nodeType === Node.TEXT_NODE && node.textContent.trim() !== ''
                );

                if (buttonTextNode) {
                    buttonTextNode.textContent = selectedMode;
                } else {
                    const textNode = document.createTextNode(selectedMode);
                    modeDropdownButton.insertBefore(textNode, modeDropdownButton.querySelector('.dropdown-arrow'));
                }

                toggleMenu();
            });
        });
    }
});
1073
+ </script>
1074
+ </body>
1075
+ </html>
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.28.0
2
+ torch==2.0.0
3
+ transformers==4.36.0
4
+ langchain==0.1.0
5
+ langchain-community==0.0.10
6
+ langchain-core==0.1.8
7
+ faiss-cpu==1.7.4
8
+ pydantic==2.5.0
9
+ numpy==1.24.3
10
+ dill==0.3.7
11
+ bitsandbytes==0.41.1
12
+ flashrank==0.2.0
13
+ PyMuPDF==1.23.8
14
+ Pillow==10.0.1
15
+ pytesseract==0.3.10
16
+ pdf2image==1.16.3
17
+ rank-bm25==0.2.2
18
+ huggingface-hub==0.18.0
19
+ peft==0.4.0
spaces_app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BeRU RAG Chat App - Optimized for Hugging Face Spaces
3
+ Deployment: https://huggingface.co/spaces/AnwinMJ/Beru
4
+ """
5
+
6
+ import streamlit as st
7
+ import torch
8
+ import os
9
+ import pickle
10
+ import faiss
11
+ import numpy as np
12
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer
13
+ from typing import List, Dict
14
+ import time
15
+ import logging
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # ========================================
22
+ # 🎨 STREAMLIT PAGE CONFIG
23
+ # ========================================
24
+ st.set_page_config(
25
+ page_title="BeRU Chat - RAG Assistant",
26
+ page_icon="πŸ€–",
27
+ layout="wide",
28
+ initial_sidebar_state="expanded",
29
+ menu_items={
30
+ "About": "BeRU - Offline RAG System with VLM2Vec and Mistral 7B"
31
+ }
32
+ )
33
+
34
+ # ========================================
35
+ # 🌍 ENVIRONMENT DETECTION
36
+ # ========================================
37
def detect_environment():
    """Detect the runtime environment (Hugging Face Spaces vs. local).

    Returns:
        dict with:
            is_spaces (bool): True when running on HF Spaces (SPACES env var
                set to 'true', or the conventional /app directory exists).
            device (str): 'cuda' if a GPU is visible to torch, else 'cpu'.
            model_cache (str): directory for HF model downloads (HF_HOME or
                './cache').
            gpu_memory (int): total VRAM of device 0 in bytes; 0 without CUDA.
    """
    # Bug fix: the original expression `'huggingface' in os.path.exists('/app')`
    # performed a membership test on a bool and raised TypeError whenever the
    # SPACES env var was unset. The /app-directory check alone is the intent.
    is_spaces = os.getenv('SPACES', 'false').lower() == 'true' or os.path.exists('/app')
    has_cuda = torch.cuda.is_available()
    return {
        'is_spaces': is_spaces,
        'device': 'cuda' if has_cuda else 'cpu',
        'model_cache': os.getenv('HF_HOME', './cache'),
        'gpu_memory': torch.cuda.get_device_properties(0).total_memory if has_cuda else 0
    }
46
+
47
+ env_info = detect_environment()
48
+
49
+ # Display environment info in sidebar
50
+ with st.sidebar:
51
+ st.write("### System Info")
52
+ st.write(f"πŸ–₯️ Device: `{env_info['device'].upper()}`")
53
+ if env_info['device'] == 'cuda':
54
+ st.write(f"πŸ’Ύ GPU VRAM: `{env_info['gpu_memory'] / 1e9:.1f} GB`")
55
+ st.write(f"πŸ“¦ Cache: `{env_info['model_cache']}`")
56
+
57
+
58
+ # ========================================
59
+ # 🎯 MODEL LOADING WITH CACHING
60
+ # ========================================
61
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model, processor and tokenizer (cached).

    Returns:
        (model, processor, tokenizer, device) β€” model is in eval mode and
        moved to GPU (fp16) when CUDA is available, CPU (fp32) otherwise.

    Raises:
        Exception: re-raised after reporting any loading failure in the UI.
    """
    with st.spinner("⏳ Loading embedding model... (first time may take 5 min)"):
        try:
            logger.info("Loading VLM2Vec model...")
            device = "cuda" if torch.cuda.is_available() else "cpu"

            repo_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"
            # Keyword arguments shared by all three from_pretrained() calls.
            shared = {
                "trust_remote_code": True,
                "cache_dir": env_info['model_cache'],
            }

            model = AutoModel.from_pretrained(
                repo_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                **shared,
            ).to(device)
            processor = AutoProcessor.from_pretrained(repo_id, **shared)
            tokenizer = AutoTokenizer.from_pretrained(repo_id, **shared)

            model.eval()
            logger.info("βœ… Embedding model loaded successfully")
            st.success("βœ… Embedding model loaded!")
            return model, processor, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading embedding model: {str(e)}")
            logger.error(f"Model loading error: {e}")
            raise
98
+
99
+
100
@st.cache_resource
def load_llm_model():
    """Load Mistral-7B-Instruct with 4-bit NF4 quantization (cached).

    Returns:
        (model, tokenizer, device) β€” model is device-mapped automatically
        by transformers/accelerate.

    Raises:
        Exception: re-raised after reporting any loading failure in the UI.
    """
    with st.spinner("⏳ Loading LLM model... (first time may take 5 min)"):
        try:
            logger.info("Loading Mistral-7B model...")
            device = "cuda" if torch.cuda.is_available() else "cpu"

            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

            repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
            cache_dir = env_info['model_cache']

            # 4-bit quantization config for memory efficiency
            quant_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

            tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
            model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                quantization_config=quant_cfg,
                device_map="auto",
                cache_dir=cache_dir,
            )

            logger.info("βœ… LLM model loaded successfully")
            st.success("βœ… LLM model loaded!")
            return model, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading LLM: {str(e)}")
            logger.error(f"LLM loading error: {e}")
            raise
139
+
140
+
141
# ========================================
# 🏠 UI LAYOUT
# ========================================
# Fix: the title contained mojibake ("οΏ½οΏ½") where an emoji was lost in an
# encoding round-trip; restored to the robot emoji used as page_icon above.
st.title("πŸ€– BeRU Chat - RAG Assistant")
st.markdown("""
A powerful offline RAG system combining Mistral 7B LLM with VLM2Vec embeddings
for intelligent document search and conversation.

**Status**: Models loading on first access (5-8 minutes)
""")
151
+
152
# Load models eagerly at script (re)run; @st.cache_resource means only the
# first run actually downloads/initializes them.
try:
    embedding_model, processor, tokenizer, device = load_embedding_model()
    llm_model, llm_tokenizer, llm_device = load_llm_model()
    models_loaded = True
except Exception as e:
    st.error(f"Failed to load models: {str(e)}")
    models_loaded = False

if models_loaded:
    # Main chat interface
    left_col, right_col = st.columns([2, 1])

    with left_col:
        st.subheader("πŸ’¬ Chat")

        # Initialize session state
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        # Chat input
        user_input = st.chat_input("Ask a question about your documents...")

        if user_input:
            # Add user message
            st.session_state.messages.append({"role": "user", "content": user_input})

            with st.chat_message("user"):
                st.write(user_input)

            # Generate response
            with st.chat_message("assistant"):
                with st.spinner("πŸ€” Thinking..."):
                    # Placeholder for RAG response
                    # NOTE(review): retrieval/generation is not wired up here;
                    # this always returns a stub string β€” confirm intended.
                    response = "Response generated from RAG system..."
                    st.write(response)
                    st.session_state.messages.append({"role": "assistant", "content": response})

    with right_col:
        st.subheader("πŸ“Š Info")
        st.info("""
**Model Info:**
- 🧠 Embedding: VLM2Vec-Qwen2VL-2B
- πŸ’¬ LLM: Mistral-7B-Instruct
- πŸ” Search: FAISS + BM25

**Performance:**
- Device: GPU if available
- Quantization: 4-bit
- Context: Multi-turn
""")

        st.subheader("βš™οΈ Settings")
        # NOTE(review): these sliders are read but never passed to generation
        # in this file β€” confirm they are consumed elsewhere.
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
        max_tokens = st.slider("Max Tokens", 100, 2000, 512)

else:
    st.error("❌ Failed to initialize models. Check logs for details.")
    st.info("Try refreshing the page or restarting the Space.")

# ========================================
# πŸ“ FOOTER
# ========================================
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <small>
        BeRU RAG System |
        <a href='https://huggingface.co/spaces/AnwinMJ/Beru'>Space</a> |
        <a href='https://github.com/AnwinMJ/BeRU'>GitHub</a>
    </small>
</div>
""", unsafe_allow_html=True)
vlm2rag2.py ADDED
@@ -0,0 +1,1354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import glob
3
+ import os
4
+ import gc
5
+ import time
6
+ import re
7
+ import hashlib
8
+ from pathlib import Path
9
+ from typing import List, Dict, Tuple, Optional
10
+ import fitz # PyMuPDF
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer # Changed from AutoModelForCausalLM
15
+ from langchain_core.documents import Document
16
+ import pickle
17
+ from numpy.linalg import norm
18
+ import camelot
19
+ import base64
20
+ import pytesseract
21
+ from pdf2image import convert_from_path
22
+ import faiss
23
+ from rank_bm25 import BM25Okapi
24
+
25
+
26
+ # ========================================
27
+ # πŸ“‚ CONFIGURATION
28
+ # ========================================
29
+ PDF_DIR = r"D:\BeRU\testing"
30
+ FAISS_INDEX_PATH = "VLM2Vec-V2rag2"
31
+ MODEL_CACHE_DIR = ".cache"
32
+ IMAGE_OUTPUT_DIR = "extracted_images2"
33
+
34
+ # Chunking configuration
35
+ CHUNK_SIZE = 450 # words
36
+ OVERLAP = 100 # words
37
+ MIN_CHUNK_SIZE = 50
38
+ MAX_CHUNK_SIZE = 800
39
+
40
+ # Instruction prefixes for better embeddings
41
+ DOCUMENT_INSTRUCTION = "Represent this technical document for semantic search: "
42
+ QUERY_INSTRUCTION = "Represent this question for finding relevant technical information: "
43
+
44
+ # Hybrid search weights
45
+ DENSE_WEIGHT = 0.4 # Weight for semantic search
46
+ SPARSE_WEIGHT = 0.6 # Weight for keyword search
47
+
48
+ # Create directories
49
+ os.makedirs(PDF_DIR, exist_ok=True)
50
+ os.makedirs(FAISS_INDEX_PATH, exist_ok=True)
51
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
52
+ os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
53
+
54
+ # ========================================
55
+ # πŸ€– VLM2Vec-V2 WRAPPER (ENHANCED)
56
+ # ========================================
57
class VLM2VecEmbeddings:
    """VLM2Vec-V2 embedding class with instruction prefixes.

    Wraps a VLM2Vec (Qwen2-VL based) checkpoint and exposes:
      - embed_documents / embed_query for text, using attention-masked mean
        pooling over the last hidden states and different instruction
        prefixes for documents vs. queries;
      - embed_image for figures/diagrams via the multimodal processor.
    """

    def __init__(self, model_name: str = "TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir: str = None):
        """Load model, processor and tokenizer and probe the embedding size.

        Args:
            model_name: Hugging Face hub id of the VLM2Vec checkpoint.
            cache_dir: optional download-cache directory.

        Raises:
            Exception: re-raised if any component fails to load.
        """
        print(f"πŸ€– Loading VLM2Vec-V2 model: {model_name}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" Device: {self.device}")

        try:
            # fp16 on GPU, fp32 on CPU (fp16 CPU inference is poorly supported).
            self.model = AutoModel.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)

            self.processor = AutoProcessor.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            self.model.eval()

            # Get actual embedding dimension by running a tiny forward pass
            # and reading the width of the last hidden state.
            test_input = self.tokenizer("test", return_tensors="pt").to(self.device)
            with torch.no_grad():
                test_output = self.model(**test_input, output_hidden_states=True)
                self.embedding_dim = test_output.hidden_states[-1].shape[-1]

            print(f" Embedding dimension: {self.embedding_dim}")
            print("βœ… VLM2Vec-V2 loaded successfully\n")

        except Exception as e:
            print(f"❌ Error loading VLM2Vec-V2: {e}")
            raise

    def normalize_text(self, text: str) -> str:
        """Normalize text for better embeddings.

        Collapses whitespace, strips "Page N" artifacts left by PDF
        extraction, and trims the result.
        """
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)

        # Remove page numbers
        text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)

        # Normalize unicode
        text = text.strip()

        return text

    def embed_documents(self, texts: List[str], add_instruction: bool = True) -> List[List[float]]:
        """Embed documents with instruction prefix and weighted mean pooling.

        Args:
            texts: raw document chunks.
            add_instruction: prepend DOCUMENT_INSTRUCTION to each text
                (disabled by embed_query, which supplies its own prefix).

        Returns:
            One embedding (list of floats) per input text.

        Raises:
            RuntimeError: if any single text fails to embed β€” deliberately
                fail-fast rather than silently dropping chunks.
        """
        embeddings = []

        with torch.no_grad():
            for text in texts:
                try:
                    # βœ… NORMALIZE TEXT
                    clean_text = self.normalize_text(text)

                    # βœ… ADD INSTRUCTION PREFIX
                    if add_instruction:
                        prefixed_text = DOCUMENT_INSTRUCTION + clean_text
                    else:
                        prefixed_text = clean_text

                    inputs = self.tokenizer(
                        prefixed_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        # Some tokenizers report a huge sentinel model_max_length;
                        # cap at 2048 (or 512 if unset).
                        max_length=min(self.tokenizer.model_max_length or 512, 2048)
                    ).to(self.device)

                    outputs = self.model(**inputs, output_hidden_states=True)

                    if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                        # βœ… WEIGHTED MEAN POOLING (ignores padding)
                        hidden_states = outputs.hidden_states[-1]
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()

                        # Apply attention mask as weights
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        # Clamp avoids divide-by-zero on an all-padding row.
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)

                        # Weighted mean
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        # Fallback to logits
                        # NOTE(review): pooled logits are vocabulary scores, not a
                        # semantic embedding β€” confirm this path is ever taken.
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_logits = outputs.logits * attention_mask
                        sum_embeddings = weighted_logits.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()

                    embeddings.append(embedding.cpu().numpy().tolist())

                except Exception as e:
                    print(f" ❌ CRITICAL: Failed to embed text: {e}")
                    print(f" Text preview: {text[:100]}")
                    raise RuntimeError(f"Embedding failed for text: {text[:50]}...") from e

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query with query-specific instruction.

        Uses QUERY_INSTRUCTION (asymmetric to the document prefix) and
        delegates to embed_documents with add_instruction=False so the
        document prefix is not applied on top.
        """
        # βœ… DIFFERENT INSTRUCTION FOR QUERIES
        clean_text = self.normalize_text(text)
        prefixed_text = QUERY_INSTRUCTION + clean_text

        # Don't add document instruction again
        return self.embed_documents([prefixed_text], add_instruction=False)[0]

    def embed_image(self, image_path: str, prompt: str = "Technical diagram") -> Optional[List[float]]:
        """Embed image with Qwen2-VL proper format.

        Args:
            image_path: path to an image file (converted to RGB).
            prompt: short text paired with the image in the chat template.

        Returns:
            Embedding as a list of floats, or None if the model exposes
            neither hidden states nor a pooler output, or on any error
            (best-effort: image failures are logged, not raised).
        """
        try:
            with torch.no_grad():
                image = Image.open(image_path).convert('RGB')

                # βœ… QWEN2-VL CORRECT FORMAT
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": prompt}
                        ]
                    }
                ]

                # Apply chat template
                text = self.processor.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )

                # Process with both text and images
                inputs = self.processor(
                    text=[text],
                    images=[image],
                    return_tensors="pt",
                    padding=True
                ).to(self.device)

                outputs = self.model(**inputs, output_hidden_states=True)

                if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                    hidden_states = outputs.hidden_states[-1]

                    # Use weighted mean pooling
                    if 'attention_mask' in inputs:
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        embedding = hidden_states.mean(dim=1).squeeze()
                else:
                    # Fallback to pooler output if available
                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                        embedding = outputs.pooler_output.squeeze()
                    else:
                        return None

                return embedding.cpu().numpy().tolist()

        except Exception as e:
            print(f" ⚠️ Failed to embed image {Path(image_path).name}: {str(e)[:100]}")
            return None
236
+
237
+
238
+ # ========================================
239
+ # πŸ” QUERY PREPROCESSING
240
+ # ========================================
241
def preprocess_query(query: str) -> str:
    """Normalize a user query: expand known abbreviations, drop trailing
    '?'/'!' runs, and collapse whitespace. Returns lower-cased text."""

    # Whole-word abbreviation expansions applied to the lower-cased query.
    expansions = (
        (r'\bh2s\b', 'hydrogen sulfide'),
        (r'\bppm\b', 'parts per million'),
        (r'\bppe\b', 'personal protective equipment'),
        (r'\bscba\b', 'self contained breathing apparatus'),
        (r'\blel\b', 'lower explosive limit'),
        (r'\bhel\b', 'higher explosive limit'),
        (r'\buel\b', 'upper explosive limit'),
    )

    text = query.lower()
    for pattern, replacement in expansions:
        text = re.sub(pattern, replacement, text)

    # Strip trailing question/exclamation marks.
    text = re.sub(r'[?!]+$', '', text)

    # Collapse runs of whitespace into single spaces.
    return re.sub(r'\s+', ' ', text).strip()
265
+
266
+
267
+ # ========================================
268
+ # πŸ“Š TABLE EXTRACTION
269
+ # ========================================
270
def is_table_of_contents_header(df, page_num):
    """Detect a table-of-contents table by scanning its first (header) row.

    Only tables on pages <= 15 are considered; a match requires at least two
    TOC keywords in the header row.
    """
    if len(df) == 0 or page_num > 15:
        return False

    # Join the first row's cells into one lower-cased string.
    header_text = ' '.join(df.iloc[0].astype(str)).lower()

    toc_keywords = ('section', 'subsection', 'description',
                    'page no', 'page number', 'contents')

    # Count keyword hits; two or more means it is a TOC header.
    hits = sum(keyword in header_text for keyword in toc_keywords)
    return hits >= 2
285
+
286
+
287
def looks_like_toc_data(df):
    """Check whether table data looks like TOC rows (section + page numbers).

    Heuristic: >70% of the last column (header row excluded) are integers
    strictly between 50 and 300, and >50% of the first column matches a
    section-number pattern such as "10" or "10.1".
    """
    if len(df) < 2 or len(df.columns) < 2:
        return False

    # Last column minus the header row: candidate page numbers.
    page_values = df.iloc[1:, -1].astype(str)
    page_like = 0
    for raw in page_values:
        stripped = raw.strip()
        if stripped.isdigit() and 50 < int(stripped) < 300:
            page_like += 1

    if len(page_values) == 0 or page_like / len(page_values) <= 0.7:
        return False

    # First column: section numbers like "10" or "10.1".
    section_values = df.iloc[1:, 0].astype(str)
    section_like = sum(1 for raw in section_values
                       if re.match(r'^\d+\.?\d*$', raw.strip()))

    return section_like / len(section_values) > 0.5
307
+
308
+
309
def extract_tables_from_pdf(pdf_path: str) -> List[Document]:
    """Extract bordered (lattice) tables from a PDF with TOC filtering.

    Uses camelot's lattice flavor only, deduplicates tables, skips the first
    5 pages, filters low-accuracy/small tables, and suppresses table-of-
    contents tables (header detection plus continuation-pattern tracking).

    Returns:
        List of Document chunks, one per accepted table, with metadata
        (source, page, heading, type="table", table_accuracy).
    """
    chunks = []

    try:
        lattice_tables = camelot.read_pdf(
            pdf_path,
            pages='all',
            flavor='lattice',  # Only bordered tables
            suppress_stdout=True
        )

        all_tables = list(lattice_tables)
        seen_tables = set()

        # Track TOC state across consecutive tables: a TOC often spans
        # several pages, so we keep skipping while rows look like TOC data.
        in_toc_section = False
        toc_start_page = None  # NOTE(review): assigned but never read

        print(f" πŸ“Š Found {len(all_tables)} bordered tables")

        for table in all_tables:
            df = table.df
            current_page = table.page

            # Deduplicate by (page, header-row) pair.
            table_id = (current_page, tuple(df.iloc[0].tolist()) if len(df) > 0 else ())
            if table_id in seen_tables:
                continue
            seen_tables.add(table_id)

            # Skip first 5 pages (title pages)
            if current_page <= 5:
                continue

            # Basic validation: require >=2 columns, >=3 rows, and camelot
            # parsing accuracy of at least 80.
            if len(df.columns) < 2 or len(df) < 3 or table.accuracy < 80:
                continue

            # Detect TOC start (page with a recognizable TOC header row).
            if not in_toc_section and is_table_of_contents_header(df, current_page):
                in_toc_section = True
                toc_start_page = current_page
                print(f" πŸ” TOC detected at page {current_page}")
                continue

            # While inside a TOC section, skip tables whose data continues
            # the TOC pattern; the first non-TOC table ends the section.
            if in_toc_section:
                if looks_like_toc_data(df):
                    print(f" ⏭️ Skipping TOC continuation on page {current_page}")
                    continue
                else:
                    # TOC ended, resume normal extraction
                    print(f" βœ… TOC ended, found real table on page {current_page}")
                    in_toc_section = False

            # Extract valid table as natural-language sentences.
            table_text = table_to_natural_language_enhanced(table)

            if table_text.strip():
                chunks.append(Document(
                    page_content=table_text,
                    metadata={
                        "source": os.path.basename(pdf_path),
                        "page": current_page,
                        "heading": "Table Data",
                        "type": "table",
                        "table_accuracy": table.accuracy
                    }
                ))

        print(f" βœ… Extracted {len(chunks)} valid tables (after TOC filtering)")

    except Exception as e:
        # Table extraction is best-effort; text extraction still proceeds.
        print(f"⚠️ Table extraction failed: {e}")

    finally:
        # Release camelot objects promptly (they can hold file handles /
        # large frames). The names may not exist if read_pdf itself failed,
        # hence the broad except.
        try:
            del lattice_tables
            del all_tables
            gc.collect()
            time.sleep(0.1)
        except:
            pass

    return chunks
395
+
396
+
397
+
398
def table_to_natural_language_enhanced(table) -> str:
    """Convert a camelot table into one natural-language sentence per row.

    The first row supplies column headers (blank headers become Column_i);
    each following row with a non-empty first cell becomes
    "<label> has <header>: <value>, ...". Rows without a row label are
    skipped. Returns "" for tables with fewer than two rows.
    """
    df = table.df

    if len(df) < 2:
        return ""

    def _is_blank(value: str) -> bool:
        # Empty strings and pandas NaN artifacts count as missing cells.
        return not value or value.lower() in ('', 'nan', 'none')

    raw_headers = [str(h).strip() for h in df.iloc[0].astype(str).tolist()]
    headers = [h if not _is_blank(h) else f"Column_{i}"
               for i, h in enumerate(raw_headers)]

    sentences = []

    for row_idx in range(1, len(df)):
        cells = [str(cell).strip() for cell in df.iloc[row_idx].astype(str).tolist()]

        # Skip rows that contain no real data at all.
        if all(_is_blank(cell) for cell in cells):
            continue

        # Only rows with a non-blank first cell (the row label) are verbalized.
        if not cells or _is_blank(cells[0]):
            continue

        parts = [
            f"{headers[i]}: {cells[i]}"
            for i in range(1, min(len(cells), len(headers)))
            if not _is_blank(cells[i])
        ]

        if parts:
            sentences.append(f"{cells[0]} has {', '.join(parts)}.")
        else:
            sentences.append(f"{cells[0]}.")

    return "\n".join(sentences)
430
+
431
+
432
def extract_tables_with_ocr(pdf_path: str, page_num: int) -> List[Dict]:
    """OCR fallback for image-based PDFs.

    Renders a single page to an image, runs Tesseract OCR, and keeps lines
    that look tabular (contain a tab or a run of 2+ spaces). A candidate
    table requires more than two such lines.

    Returns:
        A single-item list [{"text", "page", "method": "ocr"}] when a table
        candidate is found, otherwise an empty list (also on any error).
    """
    try:
        # Render only the requested page.
        images = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)

        if not images:
            return []

        ocr_text = pytesseract.image_to_string(images[0])
        lines = ocr_text.split('\n')
        table_lines = []

        # Heuristic: columns in OCR output show up as wide gaps or tabs.
        for line in lines:
            if re.search(r'\s{2,}', line) or '\t' in line:
                table_lines.append(line)

        if len(table_lines) > 2:
            return [{
                "text": "\n".join(table_lines),
                "page": page_num,
                "method": "ocr"
            }]

        return []

    except Exception as e:
        # Best-effort fallback: silently report no tables on failure.
        return []
459
+
460
+
461
def get_table_regions(pdf_path: str) -> Dict[int, List[tuple]]:
    """Collect table bounding boxes per page using BOTH camelot flavors.

    Runs lattice (bordered) and stream (whitespace-aligned) detection so the
    text extractor can exclude any text that falls inside a table region.
    TOC tables are excluded via is_table_of_contents_header.

    Returns:
        Mapping of page number -> list of bbox tuples. Empty on any error.

    NOTE(review): reads camelot's private attribute ``table._bbox``; confirm
    it remains available across camelot versions.
    """
    table_regions = {}

    try:
        lattice_tables = camelot.read_pdf(pdf_path, pages='all', flavor='lattice', suppress_stdout=True)
        stream_tables = camelot.read_pdf(pdf_path, pages='all', flavor='stream', suppress_stdout=True)

        all_tables = list(lattice_tables) + list(stream_tables)

        for table in all_tables:
            page = table.page

            # TOC tables are not real data tables; their text should still
            # be extractable as plain text.
            if is_table_of_contents_header(table.df, page):
                continue

            bbox = table._bbox

            if page not in table_regions:
                table_regions[page] = []

            # Avoid duplicate boxes when both flavors find the same table.
            if bbox not in table_regions[page]:
                table_regions[page].append(bbox)

    except Exception as e:
        # Best-effort: with no regions, text extraction simply keeps all text.
        pass

    return table_regions
489
+
490
+
491
+
492
+
493
+ # ========================================
494
+ # πŸ–ΌοΈ IMAGE EXTRACTION
495
+ # ========================================
496
def extract_images_from_pdf(pdf_path: str, output_dir: str) -> List[Dict]:
    """Extract embedded images from a PDF into *output_dir*.

    Images smaller than 10 KB are skipped (filters icons/decorations).
    Each saved image is recorded with its page number and source filename.

    NOTE(review): the raw image bytes are written as-is under a ``.png``
    filename even when the embedded image is another format (e.g. JPEG) β€”
    confirm downstream readers open by content, not extension.

    Returns:
        List of dicts: {"path", "page" (1-based), "source", "type": "image"}.
    """
    doc = fitz.open(pdf_path)
    image_data = []

    for page_num in range(len(doc)):
        page = doc[page_num]
        images = page.get_images()

        for img_index, img in enumerate(images):
            try:
                # First tuple element is the image's xref in the PDF.
                xref = img[0]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]

                # Skip tiny images (< 10 KB), likely logos or decorations.
                if len(image_bytes) < 10000:
                    continue

                image_filename = f"{Path(pdf_path).stem}_p{page_num+1}_img{img_index+1}.png"
                image_path = os.path.join(output_dir, image_filename)

                with open(image_path, "wb") as img_file:
                    img_file.write(image_bytes)

                image_data.append({
                    "path": image_path,
                    "page": page_num + 1,
                    "source": os.path.basename(pdf_path),
                    "type": "image"
                })

            except Exception as e:
                # One corrupt image must not stop extraction of the rest.
                continue

    doc.close()
    return image_data
532
+
533
+
534
+ # ========================================
535
+ # πŸ“„ TEXT EXTRACTION WITH OVERLAPPING CHUNKS
536
+ # ========================================
537
def is_bold_text(span):
    """Truthy when a PyMuPDF text span is bold: either the font name
    contains "bold" or flag bit 4 (2**4) is set."""
    if "bold" in span['font'].lower():
        return True
    return span['flags'] & 2**4
539
+
540
+
541
def is_likely_heading(text, font_size, is_bold, avg_font_size):
    """Heuristic heading detector for PDF lines.

    A heading must be bold and 3-100 chars after stripping; it then matches
    if the font is >10% larger than average, or the text is ALL CAPS, or it
    starts with a section number followed by a capitalized word.
    """
    if not is_bold:
        return False

    stripped = text.strip()
    if not 3 <= len(stripped) <= 100:
        return False

    # Noticeably larger than the document's average font.
    if font_size > avg_font_size * 1.1:
        return True

    # ALL CAPS, or numbered like "1.2 Safety".
    return stripped.isupper() or bool(re.match(r'^\d+\.?\d*\s+[A-Z]', stripped))
552
+
553
+
554
def is_inside_table(block_bbox, table_bboxes):
    """Return True when a text block's bbox overlaps any table bbox.

    Boxes are (x1, y1, x2, y2) tuples; touching edges count as overlap.
    """
    bx1, by1, bx2, by2 = block_bbox

    for tx1, ty1, tx2, ty2 in table_bboxes:
        # Two axis-aligned rectangles overlap unless one lies entirely
        # to the left/right/above/below the other.
        separated = bx2 < tx1 or bx1 > tx2 or by2 < ty1 or by1 > ty2
        if not separated:
            return True

    return False
565
+
566
+
567
def split_text_with_overlap(text: str, heading: str, source: str, page: int,
                            chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[Document]:
    """Split *text* into word-based chunks with overlap and heading context.

    Each chunk is prefixed with "Section: <heading>" so the section title is
    embedded with the content. The full section text is stored in metadata
    as ``parent_text`` for parent-document retrieval.

    Args:
        text: Full section text.
        heading: Section heading used for context and metadata.
        source: Source filename for metadata.
        page: 1-based page number for metadata.
        chunk_size: Maximum chunk length in words.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of Document chunks; a single chunk when the text fits in
        ``chunk_size`` words.
    """
    words = text.split()

    if len(words) <= chunk_size:
        # Short section: emit a single chunk with heading context.
        content_with_context = f"Section: {heading}\n\n{text}"

        return [Document(
            page_content=content_with_context,
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": 0,
                "total_chunks": 1
            }
        )]

    # BUGFIX: with overlap >= chunk_size the original stride
    # (chunk_size - overlap) was 0 (range() raises ValueError) or negative
    # (range yields nothing, silently dropping the whole section). Clamp
    # the stride to at least 1 word so every section is always chunked.
    step = max(1, chunk_size - overlap)

    chunks = []
    chunk_index = 0

    for i in range(0, len(words), step):
        chunk_words = words[i:i + chunk_size]

        # Drop a tiny trailing remainder once at least one chunk exists.
        if len(chunk_words) < MIN_CHUNK_SIZE and len(chunks) > 0:
            break

        chunk_text = " ".join(chunk_words)

        # Prefix heading context so each chunk embeds its section title.
        content_with_context = f"Section: {heading}\n\n{chunk_text}"

        chunks.append(Document(
            page_content=content_with_context,
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": chunk_index,
                "start_word": i,
                "end_word": i + len(chunk_words)
            }
        ))

        chunk_index += 1

    # Backfill the final chunk count into every chunk's metadata.
    for chunk in chunks:
        chunk.metadata["total_chunks"] = len(chunks)

    return chunks
624
+
625
+
626
def extract_text_chunks_with_overlap(pdf_path: str, table_regions: Dict[int, List[tuple]]) -> List[Document]:
    """Extract heading-aware text sections from a PDF and chunk them.

    Two passes over the document:
      1. Compute the average font size across all spans (heading baseline).
      2. Walk text blocks line by line, skipping blocks that fall inside a
         known table region, and split the text into sections at detected
         headings.

    Each section is then split into overlapping chunks via
    split_text_with_overlap.

    Args:
        pdf_path: Path of the PDF to process.
        table_regions: Page number -> list of table bboxes to exclude
            (as produced by get_table_regions).

    Returns:
        List of Document chunks with heading context and metadata.
    """
    doc = fitz.open(pdf_path)

    # Pass 1: gather all span font sizes to establish the document average.
    all_font_sizes = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        for block in blocks:
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        all_font_sizes.append(span["size"])

    # Default to 12pt when the document has no text spans.
    avg_font_size = sum(all_font_sizes) / len(all_font_sizes) if all_font_sizes else 12

    # Pass 2: accumulate text into sections delimited by detected headings.
    sections = []
    current_section = ""
    current_heading = "Introduction"  # fallback heading before the first one
    current_page = 1

    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]

        # Table bboxes for this page (fitz pages are 0-based, camelot 1-based).
        page_tables = table_regions.get(page_num + 1, [])

        for block in blocks:
            if "lines" not in block:
                continue

            # Skip text that lives inside a table; it is captured by the
            # table extractor instead.
            block_bbox = block.get("bbox", (0, 0, 0, 0))
            if is_inside_table(block_bbox, page_tables):
                continue

            for line in block["lines"]:
                line_text = ""
                line_is_bold = False
                line_font_size = 0

                # Merge spans into a single line, tracking bold/max size.
                for span in line["spans"]:
                    line_text += span["text"]
                    if is_bold_text(span):
                        line_is_bold = True
                    line_font_size = max(line_font_size, span["size"])

                line_text = line_text.strip()
                if not line_text:
                    continue

                if is_likely_heading(line_text, line_font_size, line_is_bold, avg_font_size):
                    # Close the current section before starting a new one.
                    if current_section.strip():
                        sections.append({
                            "text": current_section.strip(),
                            "heading": current_heading,
                            "page": current_page,
                            "source": os.path.basename(pdf_path)
                        })

                    current_heading = line_text
                    current_section = ""
                    current_page = page_num + 1
                else:
                    current_section += line_text + " "

    # Flush the final section.
    if current_section.strip():
        sections.append({
            "text": current_section.strip(),
            "heading": current_heading,
            "page": current_page,
            "source": os.path.basename(pdf_path)
        })

    doc.close()

    # Chunk every section with overlap and heading context.
    all_chunks = []

    for section in sections:
        chunks = split_text_with_overlap(
            text=section['text'],
            heading=section['heading'],
            source=section['source'],
            page=section['page'],
            chunk_size=CHUNK_SIZE,
            overlap=OVERLAP
        )
        all_chunks.extend(chunks)

    return all_chunks
715
+
716
+
717
+ # ========================================
718
+ # πŸ”„ COMBINED EXTRACTION
719
+ # ========================================
720
def extract_all_content_from_pdf(pdf_path: str) -> Tuple[List[Document], List[Dict]]:
    """Extract text chunks, table chunks, and images from one PDF.

    Table regions are computed first so the text extractor can exclude
    text that falls inside tables (avoiding duplicated content).

    Returns:
        (all_chunks, images): text chunks followed by table chunks, and the
        list of extracted image records (saved under IMAGE_OUTPUT_DIR).
    """

    print(f" πŸ“Š Extracting tables...")
    # Regions are needed by the text pass to skip in-table text.
    table_regions = get_table_regions(pdf_path)
    table_chunks = extract_tables_from_pdf(pdf_path)
    print(f" βœ… {len(table_chunks)} table chunks")

    print(f" πŸ“„ Extracting text...")
    text_chunks = extract_text_chunks_with_overlap(pdf_path, table_regions)
    print(f" βœ… {len(text_chunks)} text chunks")

    print(f" πŸ–ΌοΈ Extracting images...")
    images = extract_images_from_pdf(pdf_path, IMAGE_OUTPUT_DIR)
    print(f" βœ… {len(images)} images")

    all_chunks = text_chunks + table_chunks

    return all_chunks, images
739
+
740
+
741
+ # ========================================
742
+ # πŸ—οΈ BUILD FAISS INDEX WITH STREAMING
743
+ # ========================================
744
+ # Replace the HybridRetriever class and related functions with this optimized version:
745
+
746
+ # ========================================
747
+ # πŸ—οΈ BUILD FAISS INDEX WITH BM25
748
+ # ========================================
749
def build_multimodal_faiss_streaming(pdf_files: List[str], embedding_model: VLM2VecEmbeddings):
    """Build the FAISS text index, BM25 index, and optional image index.

    Steps:
      1. Hash the sorted PDF list; if the saved hash matches, prompt the
         user (stdin) before rebuilding.
      2. Extract chunks and images from every PDF (errors per-file are
         logged and skipped).
      3. Embed text chunks in batches of 10 into an inner-product FAISS
         index (vectors L2-normalized => cosine similarity), and persist
         the index plus pickled Documents.
      4. Build and persist a BM25 index over lower-cased whitespace tokens.
      5. Embed images one by one into a second FAISS index (failures are
         skipped), persisting the index and image metadata.
      6. Write the PDF-set hash marker.

    Returns:
        (text_index, all_texts), or (None, []) when the user declines a
        rebuild or no content was extracted.
    """

    # Hash of the input file set, used as a "does the index match" marker.
    index_hash_file = f"{FAISS_INDEX_PATH}/index_hash.txt"
    current_hash = hashlib.md5("".join(sorted(pdf_files)).encode()).hexdigest()

    if os.path.exists(index_hash_file):
        with open(index_hash_file, 'r') as f:
            existing_hash = f.read().strip()

        if existing_hash == current_hash:
            print("⚠️ Index already exists for these PDFs!")
            # Interactive confirmation; requires a TTY.
            response = input(" Rebuild anyway? (yes/no): ").strip().lower()
            if response != 'yes':
                return None, []

    all_texts = []
    all_image_paths = []

    print("\nπŸ“„ Processing PDFs...\n")

    for pdf_file in pdf_files:
        print(f"πŸ“– Processing: {Path(pdf_file).name}")

        try:
            text_chunks, images = extract_all_content_from_pdf(pdf_file)

            all_texts.extend(text_chunks)
            all_image_paths.extend(images)

        except Exception as e:
            # One bad PDF must not abort the whole build.
            print(f" ❌ Error: {e}")
            continue

        print()

    print(f"βœ… Total chunks: {len(all_texts)}")
    print(f"βœ… Total images: {len(all_image_paths)}\n")

    if len(all_texts) == 0:
        print("❌ No content extracted!")
        return None, []

    # Build text index incrementally, batch by batch, to bound memory use.
    print("πŸ”— Generating text embeddings...\n")

    text_index = None
    batch_size = 10

    for i in range(0, len(all_texts), batch_size):
        batch = all_texts[i:i+batch_size]
        batch_contents = [doc.page_content for doc in batch]

        try:
            batch_embeddings = embedding_model.embed_documents(batch_contents, add_instruction=True)
            batch_embeddings_np = np.array(batch_embeddings).astype('float32')

            # Lazily create the index once the embedding dimension is known.
            if text_index is None:
                dimension = batch_embeddings_np.shape[1]
                text_index = faiss.IndexFlatIP(dimension)
                print(f" Text embedding dimension: {dimension}")

            # L2-normalized vectors + inner product == cosine similarity.
            faiss.normalize_L2(batch_embeddings_np)
            text_index.add(batch_embeddings_np)

            if (i // batch_size + 1) % 5 == 0:
                print(f" Progress: {i + len(batch)}/{len(all_texts)}")

        except Exception as e:
            # Embedding failures are fatal: a partial index would desync
            # FAISS row ids from the pickled document list.
            print(f" ❌ Error: {e}")
            raise

    print(f" βœ… Complete")

    # Save FAISS index
    faiss.write_index(text_index, f"{FAISS_INDEX_PATH}/text_index.faiss")

    # Save documents (row order must match FAISS insertion order).
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "wb") as f:
        pickle.dump(all_texts, f)

    # Build and save the BM25 keyword index over the same documents.
    print("\nπŸ” Building BM25 index for keyword search...")
    tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
    bm25_index = BM25Okapi(tokenized_docs,k1=1.3, b=0.65)

    with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
        pickle.dump(bm25_index, f)

    print(" βœ… BM25 index saved")

    # Build image index (optional: only when images were extracted).
    if len(all_image_paths) > 0:
        print(f"\nπŸ–ΌοΈ Embedding images...")

        image_index = None
        successful_images = []

        for idx, img_data in enumerate(all_image_paths):
            img_embedding = embedding_model.embed_image(img_data["path"])

            # embed_image returns None on failure; keep metadata in sync by
            # only recording images that were actually indexed.
            if img_embedding is None:
                continue

            img_embedding_np = np.array([img_embedding]).astype('float32')

            if image_index is None:
                dimension = img_embedding_np.shape[1]
                image_index = faiss.IndexFlatIP(dimension)
                print(f" Image dimension: {dimension}")

            faiss.normalize_L2(img_embedding_np)
            image_index.add(img_embedding_np)
            successful_images.append(img_data)

            if (len(successful_images)) % 10 == 0:
                print(f" Progress: {len(successful_images)}/{len(all_image_paths)}")

        print(f" βœ… {len(successful_images)} images embedded")

        if image_index is not None and len(successful_images) > 0:
            faiss.write_index(image_index, f"{FAISS_INDEX_PATH}/image_index.faiss")

            with open(f"{FAISS_INDEX_PATH}/image_documents.pkl", "wb") as f:
                pickle.dump(successful_images, f)

    # Save hash marker last, so an interrupted build is retried next run.
    with open(index_hash_file, 'w') as f:
        f.write(current_hash)

    print(f"\nβœ… Index saved: {FAISS_INDEX_PATH}\n")

    return text_index, all_texts
882
+
883
+
884
+ # ========================================
885
+ # πŸ” OPTIMIZED HYBRID SEARCH
886
+ # ========================================
887
+ # ========================================
888
+ # πŸ“Š QUERY WITH BM25 ONLY
889
+ # ========================================
890
def query_with_bm25(query: str, k_text: int = 5, k_images: int = 3):
    """Query using BM25 keyword search only (no semantic retrieval).

    Loads the pickled document list and BM25 index from FAISS_INDEX_PATH
    (rebuilding BM25 in memory when the pickle is missing), scores every
    document against the preprocessed query, and returns the top *k_text*
    documents. Images are attached by page co-location: up to *k_images*
    images whose (source, page) matches a top text hit.

    Returns:
        Dict with keys "text_results", "images", "query", "processed_query".
    """

    # Normalize/expand the query before tokenization.
    processed_query = preprocess_query(query)
    print(f" πŸ” Processed: {processed_query}")

    # Load documents
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)

    # Load the persisted BM25 index; fall back to building it in memory.
    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        print(" ⚠️ BM25 index not found, building on-the-fly...")
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs)

    # BM25 scoring over the whole corpus.
    tokenized_query = processed_query.lower().split()
    bm25_scores = bm25_index.get_scores(tokenized_query)

    # Top-k by descending BM25 score.
    top_indices = np.argsort(bm25_scores)[::-1][:k_text]

    text_results = []
    relevant_pages = set()

    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        score = float(bm25_scores[idx])

        text_results.append({
            "document": doc,
            "score": score,
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        # Track (source, page) pairs for image co-location below.
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))

    # Attach images from the same pages as the top text results
    # (page co-location, not semantic image search).
    relevant_images = []

    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"

        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)

            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # co-located, not scored
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })

    except Exception as e:
        # Images are optional; any failure just yields no images.
        pass

    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query
    }
965
+
966
+
967
+ # ========================================
968
+ # πŸ“Š DISPLAY RESULTS (BM25 ONLY)
969
+ # ========================================
970
def display_results_bm25(results: Dict):
    """Pretty-print BM25-only search results to stdout.

    Expects the dict shape produced by query_with_bm25: text hits with
    score/rank/document, then co-located images. Output only; returns None.
    """

    print("\nπŸ“š TOP RESULTS (BM25 Keyword Search):\n")

    for result in results['text_results']:
        doc = result["document"]
        print(f"[{result['rank']}] BM25 Score: {result['score']:.4f} | {doc.metadata.get('type', 'N/A')}")
        print(f" πŸ“„ {doc.metadata.get('source')} - Page {doc.metadata.get('page')}")
        print(f" πŸ“Œ {doc.metadata.get('heading', 'N/A')[:60]}")

        # Show chunk position only for multi-chunk sections.
        if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1:
            print(f" πŸ”— Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}")

        # Truncated content preview.
        print(f" πŸ“ {doc.page_content[:200]}...")
        print()

    print("\nπŸ–ΌοΈ IMAGES:\n")
    if results['images']:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
    else:
        print(" No images found\n")
994
+
995
+
996
+ # ========================================
997
+ # πŸ” HYBRID SEARCH IMPLEMENTATION
998
+ # ========================================
999
+
1000
def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Min-max normalize *scores* into the [0, 1] range.

    Empty input is returned unchanged; a constant array maps to all ones
    (avoids division by zero while keeping every score "maximal").
    """
    if len(scores) == 0:
        return scores

    lo = np.min(scores)
    hi = np.max(scores)

    if hi == lo:
        return np.ones_like(scores)

    return (scores - lo) / (hi - lo)
1012
+
1013
+
1014
def query_with_hybrid(query: str, embedding_model: VLM2VecEmbeddings,
                      k_text: int = 5, k_images: int = 3,
                      dense_weight: float = DENSE_WEIGHT,
                      sparse_weight: float = SPARSE_WEIGHT):

    """
    Hybrid search combining semantic (FAISS) and keyword (BM25) retrieval.

    Both score sets are min-max normalized and fused as a weighted sum:
    combined = dense_weight * semantic + sparse_weight * bm25. Semantic
    retrieval fetches 3*k_text candidates; BM25 scores the whole corpus.
    If semantic search fails, results degrade gracefully to BM25-only.
    Images are attached by page co-location with the top text hits.

    Returns:
        Dict with "text_results" (each carrying combined/semantic/bm25
        scores), "images", "query", "processed_query", "method": "hybrid".
    """

    processed_query = preprocess_query(query)
    print(f" πŸ” Processed: {processed_query}")

    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)

    # --- SEMANTIC SEARCH (FAISS inner-product over normalized vectors) ---
    print(f" 🧠 Running semantic search...")

    try:
        text_index = faiss.read_index(f"{FAISS_INDEX_PATH}/text_index.faiss")

        query_embedding = embedding_model.embed_query(processed_query)
        query_np = np.array([query_embedding]).astype('float32')
        faiss.normalize_L2(query_np)

        # Over-fetch 3x candidates so fusion has room to rerank.
        k_retrieve = min(k_text * 3, len(text_docs))
        distances, indices = text_index.search(query_np, k_retrieve)

        semantic_scores = distances[0]
        semantic_indices = indices[0]

        print(f" βœ… Retrieved {len(semantic_indices)} semantic results")

    except Exception as e:
        # Degrade to BM25-only: empty arrays make the fusion loop a no-op.
        print(f" ⚠️ Semantic search failed: {e}")
        semantic_scores = np.array([])
        semantic_indices = np.array([])

    # --- BM25 SEARCH ---
    print(f" πŸ”€ Running BM25 keyword search...")

    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        # Rebuild in memory when the persisted index is missing.
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)

    tokenized_query = processed_query.lower().split()
    bm25_scores_all = bm25_index.get_scores(tokenized_query)

    print(f" βœ… Scored {len(bm25_scores_all)} documents")

    # --- SCORE FUSION (weighted sum of normalized scores) ---
    print(f" βš–οΈ Fusing scores (semantic: {dense_weight}, BM25: {sparse_weight})...")

    combined_scores = {}

    if len(semantic_scores) > 0:
        semantic_scores_norm = normalize_scores(semantic_scores)

        # NOTE(review): FAISS can pad results with index -1 when fewer than
        # k_retrieve vectors exist; "idx < len(text_docs)" does not filter
        # -1 out β€” confirm the index always holds >= k_retrieve vectors.
        for idx, score in zip(semantic_indices, semantic_scores_norm):
            if idx < len(text_docs):
                combined_scores[idx] = dense_weight * score

    bm25_scores_norm = normalize_scores(bm25_scores_all)

    for idx, score in enumerate(bm25_scores_norm):
        if idx in combined_scores:
            combined_scores[idx] += sparse_weight * score
        else:
            combined_scores[idx] = sparse_weight * score

    sorted_indices = sorted(combined_scores.keys(),
                            key=lambda x: combined_scores[x],
                            reverse=True)

    top_indices = sorted_indices[:k_text]

    print(f" βœ… Top {len(top_indices)} results selected")

    # --- PREPARE RESULTS ---
    text_results = []
    relevant_pages = set()

    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]

        # semantic_scores_norm is only defined when semantic search
        # succeeded; with an empty semantic_indices the membership test is
        # always False, so the name is never read in the fallback path.
        semantic_score = semantic_scores_norm[np.where(semantic_indices == idx)[0][0]] if idx in semantic_indices else 0.0
        bm25_score = bm25_scores_norm[idx]
        combined_score = combined_scores[idx]

        text_results.append({
            "document": doc,
            "score": combined_score,
            "semantic_score": float(semantic_score),
            "bm25_score": float(bm25_score),
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))

    # --- GET IMAGES (page co-location with the top text hits) ---
    relevant_images = []

    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"

        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)

            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # co-located, not scored
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception as e:
        # Images are optional; any failure just yields no images.
        pass

    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query,
        "method": "hybrid"
    }
1148
+
1149
+
1150
def display_results_hybrid(results: Dict):
    """Pretty-print hybrid (semantic + BM25) search results to stdout.

    Expects the dict shape produced by query_with_hybrid: text hits carrying
    combined/semantic/bm25 scores, then co-located images. Output only;
    returns None.
    """

    print("\nπŸ“š TOP RESULTS (Hybrid Search: Semantic + BM25):\n")

    for result in results['text_results']:
        doc = result["document"]
        # Show combined score plus both component scores for transparency.
        print(f"[{result['rank']}] Combined: {result['score']:.4f} "
              f"(Semantic: {result['semantic_score']:.4f}, BM25: {result['bm25_score']:.4f}) "
              f"| {doc.metadata.get('type', 'N/A')}")
        print(f" πŸ“„ {doc.metadata.get('source')} - Page {doc.metadata.get('page')}")
        print(f" πŸ“Œ {doc.metadata.get('heading', 'N/A')[:60]}")

        # Show chunk position only for multi-chunk sections.
        if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1:
            print(f" πŸ”— Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}")

        # Truncated content preview.
        print(f" πŸ“ {doc.page_content[:200]}...")
        print()

    print("\nπŸ–ΌοΈ IMAGES:\n")
    if results['images']:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
    else:
        print(" No images found\n")
1176
+
1177
+
1178
+ # ========================================
1179
+ # πŸ“– GET CONTEXT WITH PARENTS
1180
+ # ========================================
1181
def get_context_with_parents(results: Dict) -> List[Dict]:
    """Expand retrieval hits to full parent-section contexts.

    Hits whose metadata carries ``parent_text`` contribute the whole parent
    section (emitted once even when several chunks of that section were
    retrieved); hits without a parent contribute their own page_content.
    """
    contexts = []
    emitted_parents = set()

    for hit in results['text_results']:
        doc = hit['document']
        meta = doc.metadata
        parent_text = meta.get('parent_text')

        if parent_text:
            # Multiple chunks of the same section share one parent; emit it once.
            if parent_text in emitted_parents:
                continue
            emitted_parents.add(parent_text)
            body, is_parent = parent_text, True
        else:
            body, is_parent = doc.page_content, False

        contexts.append({
            "text": body,
            "source": meta['source'],
            "page": meta['page'],
            "heading": meta['heading'],
            "type": meta.get('type', 'text'),
            "is_parent": is_parent
        })

    return contexts
1212
+
1213
+
1214
# ========================================
# πŸš€ MAIN EXECUTION (UPDATED FOR HYBRID)
# ========================================
if __name__ == "__main__":
    # CLI entry point: load (or build) the FAISS + BM25 indexes, then run an
    # interactive query loop supporting hybrid / bm25 / semantic modes.
    print("="*70)
    print("πŸš€ RAG with HYBRID SEARCH (Semantic + BM25)")
    print("="*70 + "\n")

    pdf_files = glob.glob(f"{PDF_DIR}/*.pdf")
    print(f"πŸ“‚ Found {len(pdf_files)} PDF files\n")

    if len(pdf_files) == 0:
        print("❌ No PDFs found!")
        exit(1)

    # Load the embedding model once; both the "load existing index" and the
    # "build new index" paths below reuse this instance.
    print("\nπŸ€– Loading VLM2Vec model...")
    embedding_model = VLM2VecEmbeddings(
        model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B",
        cache_dir=MODEL_CACHE_DIR
    )

    # Load or build index
    if os.path.exists(f"{FAISS_INDEX_PATH}/text_index.faiss"):
        print("βœ… Loading existing index\n")

        # Older indexes may predate hybrid search; backfill the BM25 side
        # from the pickled documents so hybrid mode works.
        if not os.path.exists(f"{FAISS_INDEX_PATH}/bm25_index.pkl"):
            print("⚠️ BM25 index missing, building now...")

            with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
                all_texts = pickle.load(f)

            print("   Building BM25 index...")
            tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
            bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)

            with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
                pickle.dump(bm25_index, f)

            print("   βœ… BM25 index saved\n")

    else:
        print("πŸ”¨ Building new index...\n")

        # NOTE: the embedding model loaded above is reused here; re-creating
        # it would double the (multi-GB) model footprint for no benefit.
        index, documents = build_multimodal_faiss_streaming(pdf_files, embedding_model)

        if index is None:
            exit(0)

    # Interactive testing
    print("="*70)
    print("πŸ§ͺ TESTING MODE - HYBRID SEARCH")
    print(f"   Weights: Semantic {DENSE_WEIGHT} | BM25 {SPARSE_WEIGHT}")
    print("="*70 + "\n")

    test_queries = [
        "What is the higher and lower explosive limit of butane?",
        "What are the precautions taken while handling H2S?",
        "What are the Personal Protection used for Sulfolane?",
        "What is the Composition of Platforming Feed and Product?",
        "Explain Dual function platforming catalyst chemistry.",
        "Steps to be followed in Amine Regeneration Unit for normal shutdown process.",
        "Could you tell me what De-greasing of Amine System in pre startup wash",
    ]

    print("πŸ“‹ SUGGESTED QUERIES:")
    for i, q in enumerate(test_queries, 1):
        print(f"   {i}. {q}")
    print()
    print("πŸ’‘ Type 'mode' to switch between hybrid/bm25/semantic")
    print()

    current_mode = "hybrid"

    while True:
        # Prompt hint tracks the actual number of suggested queries.
        user_query = input(f"πŸ’¬ Query [{current_mode}] (or 1-{len(test_queries)}, 'mode', or 'exit'): ").strip()

        if user_query.lower() == 'exit':
            print("\nβœ… Done!")
            break

        if user_query.lower() == 'mode':
            print("\nπŸ”„ Select mode:")
            print("   1. Hybrid (Semantic + BM25)")
            print("   2. BM25 only")
            print("   3. Semantic only")
            mode_choice = input("   Choice (1-3): ").strip()

            if mode_choice == '1':
                current_mode = "hybrid"
            elif mode_choice == '2':
                current_mode = "bm25"
            elif mode_choice == '3':
                current_mode = "semantic"

            print(f"   βœ… Mode set to: {current_mode}\n")
            continue

        # A bare number selects one of the suggested queries above.
        if user_query.isdigit() and 1 <= int(user_query) <= len(test_queries):
            user_query = test_queries[int(user_query) - 1]

        if not user_query:
            continue

        print(f"\n{'='*60}")
        print(f"πŸ” Query: {user_query}")
        print(f"πŸ”§ Mode: {current_mode.upper()}")
        print(f"{'='*60}\n")

        try:
            if current_mode == "hybrid":
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3)
                display_results_hybrid(results)
            elif current_mode == "bm25":
                results = query_with_bm25(user_query, k_text=5, k_images=3)
                display_results_bm25(results)
            else:  # semantic only: hybrid path with all weight on the dense side
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3,
                                            dense_weight=1.0, sparse_weight=0.0)
                display_results_hybrid(results)

            print("\nπŸ“– FULL CONTEXT:\n")
            contexts = get_context_with_parents(results)

            for i, ctx in enumerate(contexts[:3], 1):
                print(f"[{i}] {ctx['heading'][:50]}")
                if ctx['is_parent']:
                    print(f"    βœ… Full section")
                print(f"    {ctx['text'][:300]}...\n")

            print("="*60 + "\n")

        except Exception as e:
            # Keep the REPL alive on per-query failures; show the traceback
            # for debugging.
            print(f"\n❌ Error: {e}\n")
            import traceback
            traceback.print_exc()