BeRU Deployer committed on
Commit Β·
dec533d
1
Parent(s): 8357835
Deploy BeRU Streamlit RAG System - Add app, models logic, configs, and optimizations for HF Spaces
Browse files- .hfignore +7 -0
- .streamlit/config.toml +21 -0
- Dockerfile +49 -0
- README.md +47 -6
- app.py +282 -0
- down.py +898 -0
- frontend.html +1075 -0
- requirements.txt +19 -0
- spaces_app.py +229 -0
- vlm2rag2.py +1354 -0
.hfignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignore local model and index directories when pushing to Hugging Face
|
| 2 |
+
models/
|
| 3 |
+
VLM2Vec-V2rag3/
|
| 4 |
+
faiss_index/
|
| 5 |
+
venv*
|
| 6 |
+
*.log
|
| 7 |
+
rag.log
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[global]
|
| 2 |
+
# Page customization
|
| 3 |
+
analyticsEnabled = false
|
| 4 |
+
logLevel = "info"
|
| 5 |
+
|
| 6 |
+
[client]
|
| 7 |
+
# Faster loading
|
| 8 |
+
showErrorDetails = true
|
| 9 |
+
toolbarMode = "minimal"
|
| 10 |
+
|
| 11 |
+
[server]
|
| 12 |
+
# HF Spaces optimizations
|
| 13 |
+
port = 7860
|
| 14 |
+
headless = true
|
| 15 |
+
runOnSave = false
|
| 16 |
+
maxUploadSize = 200
|
| 17 |
+
enableCORS = false
|
| 18 |
+
enableXsrfProtection = true
|
| 19 |
+
|
| 20 |
+
# Memory management
|
| 21 |
+
maxCachedMessageSize = 2
|
Dockerfile
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies (for PDF processing, OCR, etc.)
RUN apt-get update && apt-get install -y \
    build-essential \
    git \
    libpoppler-cpp-dev \
    poppler-utils \
    tesseract-ocr \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the pip layer stays cached across code-only changes
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY down.py .
COPY frontend.html .
COPY vlm2rag2.py .
# NOTE(review): check_user.py is not part of this commit's file list — confirm
# it exists in the build context or this COPY will fail the build.
COPY check_user.py .
# include the Streamlit demo as an alternative entrypoint
COPY app.py .

# Create necessary directories
RUN mkdir -p /app/.cache /app/models /app/VLM2Vec-V2rag3

# Expose HF Spaces default port
EXPOSE 7860

# by default the image will run the FastAPI server; to start the Streamlit UI
# to test locally you can override the command:
# docker run -p 7860:7860 <image> streamlit run app.py --server.port=7860

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV HF_HUB_DISABLE_SYMLINKS_WARNING=1

# default paths used by down.py (can be overridden at runtime)
ENV MODEL_DIR=/app/models
ENV LLM_MODEL_PATH=/app/models/Mistral-7B-Instruct-v0.3
ENV EMBED_MODEL_PATH=/app/models/VLM2Vec-Qwen2VL-2B
ENV FAISS_INDEX_PATH=/app/VLM2Vec-V2rag3

# Run FastAPI app on port 7860
CMD ["python", "down.py", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,13 +1,54 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: yellow
|
| 6 |
-
sdk:
|
| 7 |
-
sdk_version: 6.9.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
short_description:
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: BeRU Chat - RAG Assistant
|
| 3 |
+
emoji: π€
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: yellow
|
| 6 |
+
sdk: streamlit
|
|
|
|
| 7 |
app_file: app.py
|
| 8 |
pinned: false
|
| 9 |
+
short_description: 100% Offline RAG System with Mistral 7B and VLM2Vec
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# π€ BeRU Chat - RAG Assistant
|
| 13 |
+
|
| 14 |
+
A powerful **100% offline Retrieval-Augmented Generation (RAG) system** combining Mistral 7B LLM with VLM2Vec embeddings for intelligent document search and conversation.
|
| 15 |
+
|
| 16 |
+
## β¨ Features
|
| 17 |
+
|
| 18 |
+
- π **100% Offline Operation** - No internet required after startup
|
| 19 |
+
- π§ **Advanced RAG Architecture**
|
| 20 |
+
- Hybrid retrieval (Vector + BM25 keyword search)
|
| 21 |
+
- Ensemble retriever combining multiple strategies
|
| 22 |
+
- Re-ranking with FlashRank for relevance
|
| 23 |
+
- Multi-turn conversation with history awareness
|
| 24 |
+
- β‘ **Optimized Performance**
|
| 25 |
+
- 4-bit quantization with BitsAndBytes
|
| 26 |
+
- Flash Attention 2 support
|
| 27 |
+
- FAISS vector indexing
|
| 28 |
+
- π **Source Citations** - Every answer cites original sources
|
| 29 |
+
|
| 30 |
+
## π― Models Used
|
| 31 |
+
|
| 32 |
+
| Component | Model | Details |
|
| 33 |
+
|-----------|-------|---------|
|
| 34 |
+
| **LLM** | Mistral-7B-Instruct-v0.3 | 7B parameters |
|
| 35 |
+
| **Embedding** | VLM2Vec-Qwen2VL-2B | 2B parameters |
|
| 36 |
+
| **Vector Store** | FAISS | Meta's similarity search |
|
| 37 |
+
|
| 38 |
+
## π Getting Started
|
| 39 |
+
|
| 40 |
+
1. **Wait for Models** - First load takes 5-8 minutes (models download from HF Hub)
|
| 41 |
+
2. **Upload Documents** - Add PDFs or text files for RAG
|
| 42 |
+
3. **Ask Questions** - Chat with context-aware answers
|
| 43 |
+
4. **Get Sources** - Each answer includes citations
|
| 44 |
+
|
| 45 |
+
## π» System Requirements
|
| 46 |
+
|
| 47 |
+
- **GPU**: A10G (24GB VRAM) recommended
|
| 48 |
+
- **RAM**: 16GB minimum
|
| 49 |
+
- **Cold Start**: ~5-8 minutes (first time)
|
| 50 |
+
- **Runtime**: Streamlit app on port 7860
|
| 51 |
+
|
| 52 |
+
## π Documentation
|
| 53 |
+
|
| 54 |
+
For more information, visit the [GitHub repository](https://github.com/AnwinJosy/BeRU)
|
app.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import faiss
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
| 8 |
+
from typing import List, Dict
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
# ----------------------------------------
# Streamlit page configuration
# ----------------------------------------
# Collected in one mapping so the page settings are easy to scan and tweak.
_PAGE_CONFIG = dict(
    page_title="BeRU Chat - RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
|
| 20 |
+
|
| 21 |
+
# ========================================
|
| 22 |
+
# π― CACHING FOR MODEL LOADING
|
| 23 |
+
# ========================================
|
| 24 |
+
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model, processor, and tokenizer (cached).

    Returns:
        (model, processor, tokenizer, device) — device is "cuda" or "cpu".
    """
    st.write("⏳ Loading embedding model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Half precision on GPU keeps the 2B-parameter model within memory budget.
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModel.from_pretrained(
        "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
        trust_remote_code=True,
        torch_dtype=dtype,
    ).to(device)

    processor = AutoProcessor.from_pretrained(
        "TIGER-Lab/VLM2Vec-Qwen2VL-2B", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "TIGER-Lab/VLM2Vec-Qwen2VL-2B", trust_remote_code=True
    )

    model.eval()
    st.success("✅ Embedding model loaded!")
    return model, processor, tokenizer, device
|
| 49 |
+
|
| 50 |
+
@st.cache_resource
def load_llm_model():
    """Load Mistral-7B-Instruct with 4-bit NF4 quantization (cached).

    Returns:
        (model, tokenizer, device) — device is "cuda" or "cpu".
    """
    st.write("⏳ Loading language model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # NF4 with double quantization squeezes the 7B model onto a single GPU.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3",
        quantization_config=quantization_config,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    st.success("✅ Language model loaded!")
    return model, tokenizer, device
|
| 78 |
+
|
| 79 |
+
@st.cache_resource
def load_faiss_index():
    """Load the prebuilt FAISS text index from disk, or None if missing (cached)."""
    # NOTE(review): this path says "V2rag2" while the Dockerfile and down.py
    # use "VLM2Vec-V2rag3" — confirm which index directory is the real one.
    index_path = "VLM2Vec-V2rag2/text_index.faiss"
    if not os.path.exists(index_path):
        st.warning("⚠️ FAISS index not found. Please build the index first.")
        return None
    st.write("⏳ Loading FAISS index...")
    index = faiss.read_index(index_path)
    st.success("✅ FAISS index loaded!")
    return index
|
| 90 |
+
|
| 91 |
+
# ========================================
|
| 92 |
+
# π¬ EMBEDDING & RETRIEVAL FUNCTIONS
|
| 93 |
+
# ========================================
|
| 94 |
+
def get_embeddings(texts: List[str], model, processor, tokenizer, device) -> np.ndarray:
    """Embed each text as the mean-pooled last hidden state of the model.

    Args:
        texts: Input strings; each is tokenized independently (max 512 tokens).
        model/processor/tokenizer: VLM2Vec components (processor is unused here
            but kept in the signature for callers).
        device: Torch device for the input tensors.

    Returns:
        2-D float array of shape (len(texts), hidden_dim).
    """
    vectors = []
    for passage in texts:
        encoded = tokenizer(
            passage, return_tensors="pt", max_length=512, truncation=True
        ).to(device)

        with torch.no_grad():
            output = model(**encoded, output_hidden_states=True)
            # Mean over the sequence axis of the final layer.
            pooled = output.hidden_states[-1].mean(dim=1).cpu().numpy()

        vectors.append(pooled.flatten())

    return np.array(vectors)
|
| 108 |
+
|
| 109 |
+
def retrieve_documents(query: str, model, processor, tokenizer, device, faiss_index, k: int = 5) -> List[Dict]:
    """Retrieve the k nearest documents to *query* from a FAISS index.

    Args:
        query: User question to embed and search with.
        model/processor/tokenizer/device: Embedding components passed through
            to get_embeddings.
        faiss_index: A loaded FAISS index, or None (returns []).
        k: Number of neighbors to request.

    Returns:
        A list of {"index": id, "distance": float} dicts for valid hits
        (FAISS pads missing results with id -1, which are skipped).
    """
    if faiss_index is None:
        return []

    # Embed the query with the same model used to build the index.
    query_embedding = get_embeddings([query], model, processor, tokenizer, device)

    # FAISS returns parallel (1, k) arrays of distances and ids.
    distances, indices = faiss_index.search(query_embedding, k)

    # Pair each hit with its OWN distance. The previous
    # `distances[0][list(indices[0]).index(idx)]` looked up the FIRST
    # occurrence of the id, which is wrong when ids repeat and was O(k^2).
    results = []
    for idx, dist in zip(indices[0], distances[0]):
        if idx >= 0:
            results.append({"index": idx, "distance": float(dist)})

    return results
|
| 130 |
+
|
| 131 |
+
def generate_response(query: str, context: str, model, tokenizer, device, temperature: float = 0.7) -> str:
    """Generate an answer with Mistral-Instruct conditioned on retrieved context.

    Args:
        query: User question.
        context: Concatenated retrieved passages inserted into the prompt.
        model: Causal LM supporting .generate().
        tokenizer: Matching tokenizer (Mistral [INST] chat format).
        device: Torch device for the input tensors.
        temperature: Sampling temperature. New keyword-only-in-practice
            parameter with a default equal to the previously hard-coded 0.7,
            so existing callers are unaffected; lets the UI's temperature
            slider actually take effect.

    Returns:
        The decoded text after the final "[/INST]" tag, or the full decode
        when the tag is absent.
    """
    prompt = f"""[INST] You are a helpful assistant answering questions about technical documentation.

Context:
{context}

Question: {query} [/INST]"""

    # Truncate long contexts to the model's practical prompt budget.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=2048, truncation=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # generate() echoes the prompt; keep only the completion after the tag.
    return response.split("[/INST]")[1].strip() if "[/INST]" in response else response
|
| 154 |
+
|
| 155 |
+
# ========================================
|
| 156 |
+
# π¨ STREAMLIT UI
|
| 157 |
+
# ========================================
|
| 158 |
+
# ----------------------------------------
# Streamlit UI (flat script, runs on every rerun)
# ----------------------------------------
st.title("🤖 BeRU Chat Assistant")
st.markdown("*100% Offline RAG System with Mistral 7B & VLM2Vec*")

# --- Sidebar: runtime configuration --------------------------------------
with st.sidebar:
    st.header("⚙️ Configuration")

    device_info = "🟢 GPU" if torch.cuda.is_available() else "🔴 CPU"
    st.metric("Device", device_info)

    num_results = st.slider("Retrieve top K documents", 1, 10, 5)
    # NOTE(review): this slider's value is never passed to generate_response,
    # which samples at a fixed 0.7 — confirm whether it should be wired in.
    temperature = st.slider("Response Temperature", 0.1, 1.0, 0.7)

    st.divider()
    st.markdown("### 📚 Project Info")
    st.markdown("""
    - **Model**: Mistral 7B Instruct v0.3
    - **Embeddings**: VLM2Vec-Qwen2VL-2B
    - **Vector Store**: FAISS with 10K+ documents
    - **Retrieval**: Hybrid (Dense + BM25)
    """)

# --- Header row: prompt heading + clear button ----------------------------
col1, col2 = st.columns([3, 1])
with col1:
    st.subheader("💬 Ask a Question")
with col2:
    if st.button("🔄 Clear Chat", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# Session-state defaults: chat transcript and a one-time model-load flag.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "models_loaded" not in st.session_state:
    st.session_state.models_loaded = False

# --- One-time model loading ----------------------------------------------
if not st.session_state.models_loaded:
    st.info("📦 Loading models on first run... This may take 2-3 minutes.")
    try:
        embed_model, processor, tokenizer_embed, embed_device = load_embedding_model()
        llm_model, tokenizer_llm, llm_device = load_llm_model()
        faiss_idx = load_faiss_index()

        # Stash every handle in session state so reruns skip the load path.
        st.session_state.update(
            embed_model=embed_model,
            processor=processor,
            tokenizer_embed=tokenizer_embed,
            embed_device=embed_device,
            llm_model=llm_model,
            tokenizer_llm=tokenizer_llm,
            llm_device=llm_device,
            faiss_idx=faiss_idx,
            models_loaded=True,
        )
        st.success("✅ All models loaded successfully!")
    except Exception as e:
        st.error(f"❌ Error loading models: {str(e)}")
        st.stop()

# --- Chat transcript ------------------------------------------------------
st.markdown("---")
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

user_input = st.chat_input("Type your question here...", key="user_input")

if user_input:
    # Echo the user turn into history and onto the page.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        st.write("🔍 Retrieving relevant documents...")
        retrieved = retrieve_documents(
            user_input,
            st.session_state.embed_model,
            st.session_state.processor,
            st.session_state.tokenizer_embed,
            st.session_state.embed_device,
            st.session_state.faiss_idx,
            k=num_results,
        )

        # Only FAISS row ids are available here; presumably the chunk text
        # lives in a separate metadata store — TODO confirm and surface the
        # actual passages instead of placeholder strings.
        context = "\n\n".join(
            f"Document {i+1}: Context from index {doc['index']}"
            for i, doc in enumerate(retrieved)
        )

        st.write("📝 Generating response...")
        response = generate_response(
            user_input,
            context,
            st.session_state.llm_model,
            st.session_state.tokenizer_llm,
            st.session_state.llm_device,
        )

        st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

# --- Footer ---------------------------------------------------------------
st.markdown("---")
st.markdown("""
<div style='text-align: center; color: gray; font-size: 12px;'>
    <p>BeRU Chat Assistant | Powered by Mistral 7B + VLM2Vec | 100% Offline</p>
    <p><a href='https://github.com/AnwinJosy/BeRU'>GitHub</a> |
    <a href='https://huggingface.co/AnwinJosy'>Hugging Face</a></p>
</div>
""", unsafe_allow_html=True)
|
down.py
ADDED
|
@@ -0,0 +1,898 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
import asyncio
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Dict, Optional, Any
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from logging.handlers import RotatingFileHandler
|
| 10 |
+
|
| 11 |
+
# --- LANGCHAIN IMPORTS ---
|
| 12 |
+
from langchain_community.vectorstores import FAISS
|
| 13 |
+
from langchain.chains import create_history_aware_retriever
|
| 14 |
+
from langchain.chains.retrieval import create_retrieval_chain
|
| 15 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 16 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 17 |
+
from langchain_community.llms import HuggingFacePipeline
|
| 18 |
+
from langchain_core.embeddings import Embeddings
|
| 19 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 20 |
+
from langchain_community.retrievers import BM25Retriever
|
| 21 |
+
from langchain.retrievers import EnsembleRetriever
|
| 22 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 23 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 24 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 25 |
+
from operator import itemgetter
|
| 26 |
+
|
| 27 |
+
# --- RERANKING IMPORTS ---
|
| 28 |
+
# Ensure you have installed flashrank: pip install flashrank
|
| 29 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
| 30 |
+
from langchain_community.document_compressors import FlashrankRerank
|
| 31 |
+
|
| 32 |
+
# --- TRANSFORMERS IMPORTS ---
|
| 33 |
+
from transformers import (
|
| 34 |
+
AutoTokenizer,
|
| 35 |
+
AutoModelForCausalLM,
|
| 36 |
+
AutoModel,
|
| 37 |
+
pipeline,
|
| 38 |
+
BitsAndBytesConfig
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# --- FASTAPI IMPORTS ---
|
| 42 |
+
from fastapi import FastAPI
|
| 43 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 44 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 45 |
+
from pydantic import BaseModel, Field, field_validator
|
| 46 |
+
import uvicorn
|
| 47 |
+
import numpy as np
|
| 48 |
+
|
| 49 |
+
# -------------------------------------------------------------------------
|
| 50 |
+
# 1. Pydantic Patch (Crucial for offline serialization)
|
| 51 |
+
# -------------------------------------------------------------------------
|
| 52 |
+
def patch_pydantic_for_pickle():
    """Monkey-patch pydantic v1 ``BaseModel.__setstate__`` for pickle compatibility.

    Old FAISS/LangChain pickles may be missing ``__fields_set__`` and/or
    ``__private_attribute_values__`` in their saved state; the patch fills in
    the missing keys, and if the original ``__setstate__`` still fails it
    restores attributes directly (bypassing validation). Safe to call when
    pydantic is not installed — it only prints a warning.
    """
    try:
        from pydantic.v1.main import BaseModel as PydanticV1BaseModel
        original_setstate = PydanticV1BaseModel.__setstate__

        def patched_setstate(self, state):
            if '__fields_set__' not in state:
                state['__fields_set__'] = set(state.get('__dict__', {}).keys())
            if '__private_attribute_values__' not in state:
                state['__private_attribute_values__'] = {}
            try:
                original_setstate(self, state)
            except Exception:
                # Last resort: raw attribute restoration, skipping validation.
                object.__setattr__(self, '__dict__', state.get('__dict__', {}))
                object.__setattr__(self, '__fields_set__', state.get('__fields_set__', set()))
                object.__setattr__(self, '__private_attribute_values__', state.get('__private_attribute_values__', {}))

        PydanticV1BaseModel.__setstate__ = patched_setstate
        print("✅ Pydantic v1 patched for pickle compatibility")

    except ImportError:
        # Fallback for installs where pydantic.v1.main is not importable.
        try:
            from pydantic.v1 import BaseModel
            original_setstate = BaseModel.__setstate__

            def patched_setstate(self, state):
                if '__fields_set__' not in state:
                    state['__fields_set__'] = set(state.get('__dict__', {}).keys())
                if '__private_attribute_values__' not in state:
                    state['__private_attribute_values__'] = {}
                try:
                    original_setstate(self, state)
                except Exception:  # was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt
                    object.__setattr__(self, '__dict__', state.get('__dict__', {}))
                    object.__setattr__(self, '__fields_set__', state.get('__fields_set__', set()))
                    # Restore private attrs too, matching the primary branch above.
                    object.__setattr__(self, '__private_attribute_values__', state.get('__private_attribute_values__', {}))

            BaseModel.__setstate__ = patched_setstate
            print("✅ Pydantic patched for pickle compatibility")

        except Exception as e:
            print(f"⚠️ Could not patch Pydantic: {e}")
|
| 94 |
+
|
| 95 |
+
patch_pydantic_for_pickle()
|
| 96 |
+
|
| 97 |
+
# -------------------------------------------------------------------------
# 2. Configuration & Paths (workspace-agnostic)
# -------------------------------------------------------------------------
# Force fully offline operation for every Hugging Face library; the whole
# point of this deployment is air-gapped use.
for _offline_flag in (
    "HF_HUB_DISABLE_SYMLINKS_WARNING",
    "TRANSFORMERS_OFFLINE",
    "HF_DATASETS_OFFLINE",
    "HF_HUB_OFFLINE",
):
    os.environ[_offline_flag] = "1"

# Base directory for application files inside a container; APP_ROOT lets
# local runs point somewhere else.
ROOT_DIR = Path(os.environ.get("APP_ROOT", "/app")).resolve()

# Model and index locations may be overridden via env; defaults live in /app.
MODEL_DIR = Path(os.environ.get("MODEL_DIR", ROOT_DIR / "models"))
LLM_MODEL_PATH = Path(os.environ.get("LLM_MODEL_PATH", MODEL_DIR / "Mistral-7B-Instruct-v0.3"))
EMBED_MODEL_PATH = Path(os.environ.get("EMBED_MODEL_PATH", MODEL_DIR / "VLM2Vec-Qwen2VL-2B"))
FAISS_INDEX_PATH = Path(os.environ.get("FAISS_INDEX_PATH", ROOT_DIR / "VLM2Vec-V2rag3"))

# Generous timeout because reranking lengthens each generation cycle.
GENERATION_TIMEOUT = 240
LLM_MODEL = str(LLM_MODEL_PATH)
EMBED_MODEL = str(EMBED_MODEL_PATH)

# Rotating file logger: 10 MB per file, five backups kept.
logger = logging.getLogger("rag_system")
handler = RotatingFileHandler("rag.log", maxBytes=10 * 1024 * 1024, backupCount=5)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Module-level mutable state shared by the FastAPI handlers.
vectorstore = None
llm_pipeline = None
qa_chain = None
answer_cache: Dict[str, Dict] = {}
conversations: Dict[str, List[Dict]] = {}
|
| 134 |
+
|
| 135 |
+
# -------------------------------------------------------------------------
|
| 136 |
+
# 3. VLM2Vec Embedding Class (Preserved)
|
| 137 |
+
# -------------------------------------------------------------------------
|
class VLM2VecEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter for a locally stored VLM2Vec model.

    Texts are embedded by running the underlying transformer and mean-pooling
    the last hidden state, weighted by the attention mask.  Everything is
    loaded with ``local_files_only=True`` so the class never touches the hub.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """Load tokenizer + model from *model_path* and probe the embedding size.

        Args:
            model_path: Directory containing the local model files.
            device: "cuda" for GPU (fp16, auto device map) or "cpu" (fp32).
        """
        print(f"🔄 Loading VLM2Vec model from: {model_path}")

        self.device = device
        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        )

        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        device_map = "auto" if device == "cuda" else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            local_files_only=True,
        )

        self.model.eval()

        # Resolve the device the weights actually landed on (device_map="auto"
        # may place them somewhere other than the requested device).
        try:
            self.model_device = next(self.model.parameters()).device
        except Exception:  # fixed: bare `except:` also trapped KeyboardInterrupt
            self.model_device = torch.device("cuda" if device == "cuda" else "cpu")

        # Probe the hidden size with a dummy forward pass so _embed_single can
        # fall back to a correctly-sized zero vector on error.
        with torch.no_grad():
            test_input = self.tokenizer("test", return_tensors="pt", add_special_tokens=True)
            test_input = {k: v.to(self.model_device) for k, v in test_input.items()}
            out = self.model(**test_input, output_hidden_states=True)
            self.embedding_dim = out.hidden_states[-1].shape[-1]

        print(f"✅ VLM2Vec loaded on {self.model_device} | dim={self.embedding_dim}\n")

    def _normalize_text(self, text: str) -> str:
        """Collapse whitespace and strip 'Page N' artifacts left by PDF extraction."""
        text = re.sub(r'\s+', ' ', text or "")
        text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)
        return text.strip()

    def _ensure_non_empty(self, text: str) -> str:
        """Return normalized text, or a placeholder so the tokenizer never sees ""."""
        t = self._normalize_text(text)
        return t if t else "[EMPTY]"

    def _embed_single(self, text: str) -> List[float]:
        """Embed one text; on failure, log and return a zero vector of the right size."""
        try:
            with torch.no_grad():
                clean_text = self._ensure_non_empty(text)

                inputs = self.tokenizer(
                    clean_text,
                    return_tensors="pt",
                    add_special_tokens=True,
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {k: v.to(self.model_device) for k, v in inputs.items()}

                outputs = self.model(**inputs, output_hidden_states=True)

                if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
                    # Attention-mask-weighted mean pooling over the token axis.
                    hidden_states = outputs.hidden_states[-1]
                    attention_mask = inputs["attention_mask"].unsqueeze(-1).float()

                    weighted = hidden_states * attention_mask
                    sum_embeddings = weighted.sum(dim=1)
                    sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                    embedding = (sum_embeddings / sum_mask).squeeze(0)
                else:
                    # Fallback for models that only expose logits.
                    embedding = outputs.logits.mean(dim=1).squeeze(0)

                return embedding.cpu().numpy().tolist()

        except Exception as e:
            logger.error(f"VLM2Vec embedding error: {e}")
            # Zero vector keeps FAISS shapes consistent; the 1024 default is
            # only reached if __init__'s probe failed before embedding_dim was set.
            return [0.0] * getattr(self, "embedding_dim", 1024)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document texts (sequentially, one forward pass each)."""
        return [self._embed_single(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(text)
# -------------------------------------------------------------------------
# 4. Prompt Templates (CLEANER & STRICTER)
# -------------------------------------------------------------------------
# Answer-style prompt templates in the Mistral [INST] chat format.  Keys are
# the valid values for ChatRequest.mode / QueryRequest.answer_style; both
# validators fall back to "Detailed" for unknown keys.
PROMPT_TEMPLATES = {
    "Short and Concise": """<s>[INST] Answer the question based ONLY on the following context. Keep the answer under 3 sentences.

Context:
{context}

Question:
{input} [/INST]""",

    "Detailed": """<s>[INST] You are a helpful assistant. Answer the question using ONLY the following context. Provide a detailed summary (4-5 sentences).

Context:
{context}

Question:
{input} [/INST]""",

    "Step-by-Step": """<s>[INST] Based on the context below, provide a step-by-step procedure to answer the question.

Context:
{context}

Question:
{input} [/INST]""",
}
def structure_answer(answer: str, style: str) -> str:
    """Post-process a raw LLM completion before returning it to the client.

    Strips known prompt/chat artifacts, truncates at markers that indicate the
    model started hallucinating a new conversation turn, and (for the short
    style) trims the answer to its first two sentences.

    Args:
        answer: Raw text produced by the LLM.
        style:  Prompt style name; "Short and Concise" triggers truncation.

    Returns:
        The cleaned answer ("" when nothing substantive remains).
    """
    # 1. Remove template/chat artifacts the model sometimes echoes back.
    artifacts = [
        "Enough thinking",
        "Note:",
        "System:",
        "User:",
        "[/INST]",
        "Here is the answer:",
        "Answer:",
    ]
    for artifact in artifacts:
        # str.replace is a no-op when absent, so no membership guard is needed.
        answer = answer.replace(artifact, "")

    # 2. Cut at the first marker signalling a fabricated follow-up turn.
    stop_markers = ["Human:", "Question:", "User input:", "Context:"]
    for marker in stop_markers:
        if marker in answer:
            answer = answer.split(marker)[0]

    clean_answer = answer.strip()

    # 3. Final Formatting: keep only the first two sentences for short mode.
    if style == "Short and Concise":
        # Strip each fragment so rejoining doesn't produce double spaces
        # (old code yielded "A.  B."), and drop empties so a blank answer
        # returns "" instead of a lone ".".
        sentences = [s.strip() for s in clean_answer.split('.') if s.strip()]
        clean_answer = ". ".join(sentences[:2]) + "." if sentences else ""

    return clean_answer
| 290 |
+
# -------------------------------------------------------------------------
# 5. Load System
# -------------------------------------------------------------------------
def load_system():
    """Load every component of the offline RAG stack.

    Verifies the local LLM, embedding model and FAISS index exist, then loads
    the vector store, the LLM pipeline and the retrieval chain in that order.

    Raises:
        FileNotFoundError: when any required local artifact is missing.
    """
    global vectorstore, llm_pipeline, qa_chain

    # Fail fast with a precise message for each missing artifact.
    if not LLM_MODEL_PATH.exists():
        raise FileNotFoundError(f"LLM model not found at: {LLM_MODEL_PATH}")
    if not EMBED_MODEL_PATH.exists():
        raise FileNotFoundError(f"Embedding model not found at: {EMBED_MODEL_PATH}")
    if not FAISS_INDEX_PATH.exists():
        raise FileNotFoundError(
            f"FAISS index not found at: {FAISS_INDEX_PATH}\n"
            f"Please run the rebuild_faiss_index.py script first!"
        )

    banner = "=" * 70
    print("\n" + banner)
    print("🚀 LOADING RAG SYSTEM: Mistral 7B + VLM2Vec + Reranking (OFFLINE)")
    print(banner + "\n")

    # Order matters: the chain builder needs both the store and the LLM.
    _load_vectorstore()
    _load_llm()
    _build_retrieval_chain()

    print("✅ RAG system ready (100% OFFLINE)!\n")
def _load_embeddings():
    """Instantiate the local VLM2Vec embedding model on the best available device."""
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return VLM2VecEmbeddings(model_path=EMBED_MODEL_PATH, device=target_device)
def _load_vectorstore():
    """Load the prebuilt FAISS index and pickled documents into `vectorstore`.

    Tries several pickle flavors (pickle5, latin1-encoded pickle, dill) so an
    index built under a different Python version can still be opened, then
    reconstructs a LangChain FAISS store around the raw index.

    Raises:
        FileNotFoundError: when the index or document files are missing.
        RuntimeError: when no loader can deserialize the documents.
        ValueError: when the pickle holds an unexpected container type.
    """
    global vectorstore

    import faiss
    import pickle
    from langchain_community.docstore.in_memory import InMemoryDocstore
    from langchain_core.documents import Document

    print(f"📥 Loading FAISS index from: {FAISS_INDEX_PATH}")

    text_index_path = os.path.join(FAISS_INDEX_PATH, "text_index.faiss")
    text_docs_path = os.path.join(FAISS_INDEX_PATH, "text_documents.pkl")

    if not os.path.exists(text_index_path):
        raise FileNotFoundError("text_index.faiss not found")
    if not os.path.exists(text_docs_path):
        raise FileNotFoundError("text_documents.pkl not found")

    embedding_model = _load_embeddings()

    try:
        index = faiss.read_index(text_index_path)
        print(f"   📊 FAISS index loaded: {index.ntotal} vectors")

        print("   📄 Loading documents...")

        documents = None

        # Robust loading mechanism: pickle5 -> latin1 pickle -> dill.
        # (fixed: `except (ImportError, Exception)` was redundant — Exception
        # already covers ImportError; pickle5 is optional, so just fall through.)
        try:
            import pickle5
            with open(text_docs_path, 'rb') as f:
                documents = pickle5.load(f)
            print("   ✅ Loaded with pickle5")
        except Exception:
            pass

        if documents is None:
            try:
                with open(text_docs_path, 'rb') as f:
                    documents = pickle.load(f, encoding='latin1')
                print("   ✅ Loaded with latin1 encoding")
            except Exception:
                pass

        if documents is None:
            try:
                import dill
                with open(text_docs_path, 'rb') as f:
                    documents = dill.load(f)
                print("   ✅ Loaded with dill")
            except Exception as e:
                print(f"   ⚠️ dill failed: {e}")
                # Chain the cause so the original deserialization error is visible.
                raise RuntimeError("Could not load documents. Check pickle version.") from e

        if isinstance(documents, list):
            print(f"   Loaded {len(documents)} documents")

            # Rebuild Document objects in case the pickle holds a foreign
            # (older langchain) Document class or plain objects.
            reconstructed_docs = []
            for doc in documents:
                if isinstance(doc, Document):
                    reconstructed_docs.append(doc)
                else:
                    try:
                        new_doc = Document(
                            page_content=doc.page_content if hasattr(doc, 'page_content') else str(doc),
                            metadata=doc.metadata if hasattr(doc, 'metadata') else {}
                        )
                        reconstructed_docs.append(new_doc)
                    except Exception as e:
                        print(f"   ⚠️ Could not reconstruct document: {e}")

            documents = reconstructed_docs
            # FAISS row i maps to docstore key str(i).
            docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
            index_to_docstore_id = {i: str(i) for i in range(len(documents))}

        elif isinstance(documents, dict):
            print(f"   Loaded {len(documents)} documents (dict)")
            docstore = InMemoryDocstore(documents)
            index_to_docstore_id = {i: key for i, key in enumerate(documents.keys())}

        else:
            raise ValueError(f"Unexpected documents format: {type(documents)}")

        vectorstore = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id
        )

        print(f"   📊 Total vectors: {vectorstore.index.ntotal}")
        print("✅ FAISS vectorstore loaded\n")

    except Exception as e:
        print(f"❌ Error loading FAISS index: {e}")
        import traceback
        traceback.print_exc()
        raise
def _load_llm():
    """Load the local Mistral LLM (4-bit NF4) and wrap it in a LangChain pipeline.

    Attempts Flash Attention 2 first and transparently falls back to the
    default attention implementation when the installed build or hardware
    does not support it.  Sets the module-level `llm_pipeline` global.
    """
    print(f"🤖 Loading LLM from: {LLM_MODEL_PATH} (OFFLINE - SPEED OPTIMIZED)")

    # 4-bit NF4 quantization with double-quant keeps the 7B model small enough
    # for a single GPU while computing in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH, local_files_only=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # CHECK FOR FLASH ATTENTION SUPPORT
    # (Fall back to standard attention if not supported.)
    try:
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_PATH,
            quantization_config=bnb_config,
            device_map="auto",
            local_files_only=True,
            attn_implementation="flash_attention_2"  # <--- SPEED BOOST
        )
        print("   ⚡ Flash Attention 2 Enabled!")
    except Exception:  # fixed: bare `except:` would also trap KeyboardInterrupt
        print("   ⚠️ Flash Attention 2 not supported. Using standard attention.")
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_PATH,
            quantization_config=bnb_config,
            device_map="auto",
            local_files_only=True,
        )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.01,  # near-greedy decoding for factual answers
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False  # Stop repetition: don't echo the prompt back
    )

    global llm_pipeline
    llm_pipeline = HuggingFacePipeline(pipeline=pipe)
    print("✅ LLM Loaded\n")
def format_docs_with_sources(docs):
    """Render retrieved documents as one context string with citation headers.

    Each document is prefixed with a reference line naming its source file and
    page (from metadata, defaulting to 'Unknown Document' / '?') so the LLM
    can cite it in the answer.
    """
    rendered = []
    for document in docs:
        origin = document.metadata.get("source", "Unknown Document")
        page_no = document.metadata.get("page", "?")
        rendered.append(
            f"--- REFERENCE: {origin} (Page {page_no}) ---\n{document.page_content}\n"
        )
    return "\n\n".join(rendered)
def _build_retrieval_chain():
    """Assemble the LCEL chain: hybrid retrieval -> rerank -> history-aware QA.

    Pipeline: FAISS vector search + BM25 keyword search combined in an
    ensemble (0.3 / 0.7), optionally compressed by a FlashRank reranker to the
    top 5 docs, made history-aware via a question-rephrasing step, then fed
    into a citation-oriented answer prompt.  Sets the `qa_chain` global.
    """
    global qa_chain
    print("🔗 Building Production RAG Chain (Sources + Hybrid)...")

    # --- A. RETRIEVER SETUP (Speed Optimized) ---

    # 1. Vector Search
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

    # 2. BM25 (Keyword Search)
    try:
        all_docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(all_docs)
        bm25_retriever.k = 10

        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.3, 0.7]
        )
    except Exception as e:  # fixed: bare `except:` hid why BM25 was skipped
        print(f"   ⚠️ BM25 unavailable ({e}); falling back to vector search only.")
        ensemble_retriever = faiss_retriever

    # 3. Reranking (Top 5 only)
    try:
        compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2", top_n=5)
        final_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=ensemble_retriever
        )
    except Exception as e:  # fixed: bare `except:` hid why reranking was skipped
        print(f"   ⚠️ Reranker unavailable ({e}); using uncompressed retrieval.")
        final_retriever = ensemble_retriever

    # --- B. HISTORY AWARENESS ---

    # Reformulate question based on chat history
    rephrase_prompt = ChatPromptTemplate.from_template(
        """<s>[INST] Rephrase the follow-up question to be a standalone question.
Chat History: {chat_history}
Follow Up Input: {input}
Standalone question: [/INST]"""
    )

    history_node = create_history_aware_retriever(
        llm_pipeline,
        final_retriever,
        rephrase_prompt
    )

    # --- C. FINAL ANSWER GENERATION (With Sources) ---

    qa_prompt = ChatPromptTemplate.from_template(
        """[INST] You are a helpful assistant for BPCL-Kochi Refinery.
Answer the user's question based strictly on the context provided below.
If the answer is not in the context, say "I don't have that information in the manuals."
ALWAYS cite the document name for your answer.

CONTEXT WITH SOURCES:
{context}

USER QUESTION:
{input}

ANSWER: [/INST]"""
    )

    # The Chain (No Cache): retrieval feeds the formatted, source-annotated
    # context into the QA prompt; the LLM output is parsed to a plain string.
    qa_chain = (
        {
            "context": history_node | format_docs_with_sources,
            "input": itemgetter("input"),
            "chat_history": itemgetter("chat_history"),
        }
        | qa_prompt
        | llm_pipeline
        | StrOutputParser()
    )

    print("✅ Production Chain Built (with Citations)\n")
# -------------------------------------------------------------------------
# 6. FastAPI App & Endpoints
# -------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load models, index and chain before serving any request.
    print("\n🚀 Starting application (OFFLINE)...")
    load_system()
    logger.info("RAG system initialized (OFFLINE)")

    yield  # the application handles requests while suspended here

    # Shutdown: drop the answer cache and release cached CUDA memory.
    print("\n🛑 Shutting down...")
    answer_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Shutdown complete")
# FastAPI application; `lifespan` performs model loading/unloading.
app = FastAPI(
    title="BeRU Chat Assistant - VLM2Vec",
    description="100% Offline RAG system with VLM2Vec embeddings",
    version="2.0-VLM2Vec",
    lifespan=lifespan,
)

# Allow the static frontend to call the API from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is acceptable for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatRequest(BaseModel):
    # Request payload for POST /api/chat.
    message: str = Field(..., min_length=1, max_length=2000)  # user question
    mode: str = "Detailed"                  # one of the PROMPT_TEMPLATES keys
    session_id: Optional[str] = "default"   # conversation history bucket
    include_images: bool = False            # accepted but not read by the endpoint yet

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        # Trim surrounding whitespace before the message is used or cached.
        return v.strip()

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v):
        # Unknown prompt styles silently fall back to "Detailed".
        if v not in PROMPT_TEMPLATES:
            return "Detailed"
        return v
class QueryRequest(BaseModel):
    # Request payload for the legacy POST /api/query endpoint.
    message: str = Field(..., min_length=1, max_length=2000)  # user question
    answer_style: str = "Detailed"                 # maps onto ChatRequest.mode
    num_sources: int = Field(default=5, ge=1, le=10)  # accepted but currently unused downstream

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        # Trim surrounding whitespace before forwarding to the chat endpoint.
        return v.strip()

    @field_validator("answer_style")
    @classmethod
    def validate_style(cls, v):
        # Unknown styles silently fall back to "Detailed".
        if v not in PROMPT_TEMPLATES:
            return "Detailed"
        return v
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page frontend, or a diagnostic page when it is absent."""
    try:
        frontend_path = Path("frontend.html")
        if not frontend_path.exists():
            # Guard clause: explain exactly where the file was expected.
            return f"""
            <html>
            <body>
                <h1>Error: frontend.html not found</h1>
                <p>Please place frontend.html in the same directory as this script</p>
                <p>Current directory: {Path.cwd()}</p>
            </body>
            </html>
            """
        with open(frontend_path, "r", encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        return f"<html><body><h1>Error loading frontend</h1><p>{str(e)}</p></body></html>"
# Limit concurrent chat generations to 3; extra requests queue on the semaphore.
query_semaphore = asyncio.Semaphore(3)
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer a chat message via the RAG chain.

    Flow: acquire the concurrency semaphore -> serve from the answer cache if
    possible -> rebuild LangChain message history (last 3 turns) -> run the
    chain in a worker thread under GENERATION_TIMEOUT -> clean the answer ->
    cache it and append it to the session history.  Returns JSON with keys
    response/sources/mode/cached/images; 504 on timeout, 500 on any error.
    """
    async with query_semaphore:
        try:
            message = request.message
            mode = request.mode
            session_id = request.session_id

            logger.info(f"Chat Query: {message[:100]} | Mode: {mode}")
            print(f"\n{'=' * 60}")
            print(f"💬 Chat: {message}")
            print(f"   Mode: {mode}")
            print(f"   Session: {session_id}")

            # History Management: create the session bucket on first use.
            if session_id not in conversations:
                conversations[session_id] = []

            # Check Cache (keyed by message + mode + session).
            cache_key = f"{message}_{mode}_{session_id}"
            if cache_key in answer_cache:
                print("💾 Cache hit!")
                cached_response = answer_cache[cache_key]
                # Cached answers still count as a turn in the history.
                conversations[session_id].append(
                    {
                        "user": message,
                        "bot": cached_response["response"],
                        "mode": mode,
                    }
                )
                return JSONResponse(cached_response)

            print(f"⏱️ Generating response (timeout: {GENERATION_TIMEOUT}s)...")

            # Convert dict history to LangChain Objects (Last 3 turns)
            chat_history_objs = []
            for turn in conversations[session_id][-3:]:
                # Ensure you have these imported from langchain_core.messages
                chat_history_objs.append(HumanMessage(content=turn["user"]))
                chat_history_objs.append(AIMessage(content=turn["bot"]))

            # Execute Chain in a thread so the event loop stays responsive,
            # bounded by GENERATION_TIMEOUT.
            try:
                result = await asyncio.wait_for(
                    asyncio.to_thread(
                        qa_chain.invoke,
                        {
                            "input": message,
                            "chat_history": chat_history_objs
                        },
                    ),
                    timeout=GENERATION_TIMEOUT,
                )
            except asyncio.TimeoutError:
                return JSONResponse(
                    {
                        "error": f"Query timeout after {GENERATION_TIMEOUT}s",
                        "response": "Sorry, the request took too long. Please try again.",
                    },
                    status_code=504,
                )

            # --- CRITICAL FIX START ---
            # The new chain returns a String directly. The old one returned a Dict.
            # We must handle both cases to prevent the AttributeError.

            context_docs = []  # Default to empty if using string chain

            if isinstance(result, str):
                # New "Production Chain" path
                answer = result
                # Note: In this mode, citations are embedded in the text string
                # (e.g. "Reference: Manual..."), so we don't have raw docs for the 'sources' list.
            elif isinstance(result, dict):
                # Old "Standard Chain" path
                answer = result.get("answer", "No answer generated")
                context_docs = result.get("context", [])
            else:
                answer = str(result)

            # Clean up the answer text (artifact stripping / truncation).
            answer = structure_answer(answer, mode)
            # --- CRITICAL FIX END ---

            # Process Sources (Only populates if context_docs were returned)
            sources = []
            for i, doc in enumerate(context_docs[:5], 1):
                sources.append(
                    {
                        "index": i,
                        "file_name": doc.metadata.get("source", "Unknown"),
                        "page": doc.metadata.get("page", "N/A"),
                        "snippet": doc.page_content[:200].replace("\n", " "),
                    }
                )

            print(f"✅ Response generated: {len(answer)} chars")

            response_data = {
                "response": answer,
                "sources": sources,
                "mode": mode,
                "cached": False,
                "images": []  # Placeholder for image handling
            }

            # Store in cache and append to the session history.
            answer_cache[cache_key] = response_data

            conversations[session_id].append(
                {
                    "user": message,
                    "bot": answer,
                    "mode": mode,
                }
            )

            logger.info("Chat response completed")
            return JSONResponse(response_data)

        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            print(f"❌ ERROR: {e}")
            # Ensure traceback is printed to console for debugging
            import traceback
            traceback.print_exc()
            return JSONResponse(
                {
                    "error": str(e),
                    "response": "Sorry, an internal error occurred. Please check server logs.",
                },
                status_code=500,
            )
@app.post("/api/query")
async def query_endpoint(request: QueryRequest):
    """Legacy query API: delegates to /api/chat and renames 'response' to 'answer'."""
    import json

    forwarded = ChatRequest(
        message=request.message,
        mode=request.answer_style,
        session_id="default",
    )
    chat_response = await chat_endpoint(forwarded)
    payload = json.loads(chat_response.body.decode("utf-8"))
    if "response" in payload:
        payload["answer"] = payload.pop("response")
    return JSONResponse(payload)
@app.get("/api/health")
async def health_check():
    """Liveness probe returning a snapshot of runtime configuration and counters."""
    snapshot = {
        "status": "ok",
        "mode": "OFFLINE",
        "llm_model": LLM_MODEL,
        "embedding_model": EMBED_MODEL,
        "cuda_available": torch.cuda.is_available(),
        "cache_size": len(answer_cache),
        "active_sessions": len(conversations),
    }
    return snapshot
@app.get("/api/stats")
async def get_stats():
    """Report index size, cache size and model locations for monitoring."""
    if vectorstore:
        try:
            doc_count = len(vectorstore.docstore._dict)
        except Exception:
            # Docstore internals unavailable — report an unknown count
            # rather than failing the endpoint.
            doc_count = "unknown"
    else:
        doc_count = 0

    return {
        "mode": "OFFLINE",
        "documents": doc_count,
        "cache_size": len(answer_cache),
        "active_sessions": len(conversations),
        "llm_model": LLM_MODEL,
        "embedding_model": EMBED_MODEL,
        "cuda_available": torch.cuda.is_available(),
        "index_path": FAISS_INDEX_PATH,
    }
@app.post("/api/new-conversation")
async def new_conversation(request: dict):
    """Reset the stored history for a session (if one exists) and acknowledge."""
    sid = request.get("session_id", "default")
    if sid in conversations:
        conversations[sid] = []
    return {"message": "New conversation started", "session_id": sid}
@app.get("/api/conversation/{session_id}")
async def get_conversation(session_id: str):
    """Return the stored chat history for a session (empty list when unknown)."""
    return {"history": conversations.get(session_id, [])}
@app.get("/api/clear_cache")
async def clear_cache():
    """Empty the answer cache and release any cached CUDA memory."""
    removed = len(answer_cache)
    answer_cache.clear()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {"message": f"Cache cleared. Removed {removed} entries"}
if __name__ == "__main__":
    import sys  # NOTE(review): unused in this block — confirm before removing
    import argparse

    # Only configurable option: the listening port (default 8001).
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8001, help="Port to run the server on")
    args = parser.parse_args()

    port = args.port

    # Startup banner listing key endpoints and the local artifact locations.
    print("\n" + "=" * 70)
    print("🚀 BeRU Chat Assistant - VLM2Vec Mode (100% OFFLINE)")
    print("=" * 70)
    print(f"\n📍 Frontend: http://localhost:{port}")
    print(f"📍 API Docs: http://localhost:{port}/docs")
    print(f"📍 Health: http://localhost:{port}/api/health")
    print(f"📍 Stats: http://localhost:{port}/api/stats")
    print(f"\n📦 Embedding Model (LOCAL): {EMBED_MODEL_PATH}")
    print(f"📦 LLM Model (LOCAL): {LLM_MODEL_PATH}")
    print(f"📦 FAISS Index: {FAISS_INDEX_PATH}")
    print("🌐 Mode: 100% OFFLINE (local files only)")
    print("=" * 70 + "\n")

    # Bind on all interfaces so the containerized app is reachable.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
frontend.html
ADDED
|
@@ -0,0 +1,1075 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>BeRU Chat - Multimodal</title>
|
| 7 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
|
| 8 |
+
|
| 9 |
+
<style>
|
| 10 |
+
/* *** EXISTING STYLES (keeping all your original styles) *** */
|
| 11 |
+
body {
|
| 12 |
+
font-family: 'Roboto', sans-serif;
|
| 13 |
+
margin: 0;
|
| 14 |
+
padding: 0;
|
| 15 |
+
transition: background-color 0.5s ease, color 0.5s ease;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
.light-mode {
|
| 19 |
+
background-color: #caf2fa;
|
| 20 |
+
color: #333;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
.dark-mode {
|
| 24 |
+
background-color: #1e1e1e;
|
| 25 |
+
color: #f5f5f5;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.dark-mode .chat-container {
|
| 29 |
+
background-color: #1e1e1e;
|
| 30 |
+
color: #f5f5f5;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.sidebar {
|
| 34 |
+
width: 300px;
|
| 35 |
+
height: 100vh;
|
| 36 |
+
background-color: #0d131a;
|
| 37 |
+
position: fixed;
|
| 38 |
+
top: 0;
|
| 39 |
+
left: 0;
|
| 40 |
+
z-index: 1;
|
| 41 |
+
overflow-x: hidden;
|
| 42 |
+
transition: width 0.3s, background-color 0.5s ease;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.light-mode .sidebar {
|
| 46 |
+
background-color: #01414e;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.sidebar.collapsed {
|
| 50 |
+
width: 50px;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.tooltip {
|
| 54 |
+
position: absolute;
|
| 55 |
+
top: 0;
|
| 56 |
+
right: -20px;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.tooltip .tooltiptext {
|
| 60 |
+
visibility: hidden;
|
| 61 |
+
width: 120px;
|
| 62 |
+
background-color: rgb(0, 0, 0);
|
| 63 |
+
color: #fff;
|
| 64 |
+
text-align: center;
|
| 65 |
+
border-radius: 6px;
|
| 66 |
+
padding: 5px 0;
|
| 67 |
+
position: absolute;
|
| 68 |
+
z-index: 1;
|
| 69 |
+
bottom: 125%;
|
| 70 |
+
left: 50%;
|
| 71 |
+
margin-left: -60px;
|
| 72 |
+
opacity: 0;
|
| 73 |
+
transition: opacity 0.3s;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
#sidebar-toggle {
|
| 77 |
+
background-color: transparent;
|
| 78 |
+
margin: -22%;
|
| 79 |
+
align-items: center;
|
| 80 |
+
margin-top: 600%;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.tooltip .tooltiptext::after {
|
| 84 |
+
content: "";
|
| 85 |
+
position: absolute;
|
| 86 |
+
top: 100%;
|
| 87 |
+
left: 50%;
|
| 88 |
+
border-width: 5px;
|
| 89 |
+
border-style: solid;
|
| 90 |
+
border-color: black transparent transparent transparent;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.tooltip:hover .tooltiptext {
|
| 94 |
+
visibility: visible;
|
| 95 |
+
opacity: 1;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.sidebar-content {
|
| 99 |
+
padding-top: 20px;
|
| 100 |
+
transition: opacity 0.3s;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
.sidebar.collapsed .sidebar-content {
|
| 104 |
+
opacity: 0;
|
| 105 |
+
pointer-events: none;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
.conversation-list {
|
| 109 |
+
padding: 0 20px;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.conversation {
|
| 113 |
+
margin-bottom: 10px;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
.conversation-text {
|
| 117 |
+
font-weight: bold;
|
| 118 |
+
color: #fff;
|
| 119 |
+
transition: color 0.5s ease;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.light-mode .conversation-text {
|
| 123 |
+
color: #ccc;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.conversation-content {
|
| 127 |
+
color: #ddd;
|
| 128 |
+
transition: color 0.5s ease;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
.light-mode .conversation-content {
|
| 132 |
+
color: #888;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
#new-conversation-btn {
|
| 136 |
+
background-color: #3a3b3b;
|
| 137 |
+
color: #fff;
|
| 138 |
+
border: none;
|
| 139 |
+
padding: 10px 20px;
|
| 140 |
+
border-radius: 5px;
|
| 141 |
+
cursor: pointer;
|
| 142 |
+
transition: background-color 0.3s, color 0.5s ease;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
#new-conversation-btn:hover {
|
| 146 |
+
background-color: #242020;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
.light-mode #new-conversation-btn {
|
| 150 |
+
background-color: #c9c9c9;
|
| 151 |
+
color: #171717;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.light-mode #new-conversation-btn:hover {
|
| 155 |
+
background-color: #e0e0e0;
|
| 156 |
+
color: #171717;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
.chat-container {
|
| 160 |
+
width: calc(100% - 300px);
|
| 161 |
+
margin-left: 300px;
|
| 162 |
+
height: 100vh;
|
| 163 |
+
overflow: hidden;
|
| 164 |
+
transition: all 0.3s ease-in-out, background-color 0.5s ease;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
.sidebar.collapsed ~ .chat-container {
|
| 168 |
+
width: calc(100% - 50px);
|
| 169 |
+
margin-left: 50px;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.chat-content {
|
| 173 |
+
display: flex;
|
| 174 |
+
flex-direction: column;
|
| 175 |
+
height: 100%;
|
| 176 |
+
padding-bottom: 80px;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
.logo-container {
|
| 180 |
+
display: flex;
|
| 181 |
+
align-items: center;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.logo {
|
| 185 |
+
width: 30px;
|
| 186 |
+
height: 30px;
|
| 187 |
+
margin-right: 10px;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
.chat-header {
|
| 191 |
+
margin-left: 2%;
|
| 192 |
+
display: flex;
|
| 193 |
+
align-items: center;
|
| 194 |
+
justify-content: space-between;
|
| 195 |
+
font-size: 10px;
|
| 196 |
+
height: 60px;
|
| 197 |
+
background-color: #171717;
|
| 198 |
+
transition: background-color 0.5s ease, color 0.5s ease;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
.light-mode h1 {
|
| 202 |
+
color: black;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
.dark-mode h1 {
|
| 206 |
+
color: #f5f5f5;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
.light-mode .chat-header {
|
| 210 |
+
background-color: #caf2fa;
|
| 211 |
+
color: #333;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
.dark-mode .chat-header {
|
| 215 |
+
background-color: #1e1e1e;
|
| 216 |
+
color: #f5f5f5;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
h1 {
|
| 220 |
+
color: #cfcfcf;
|
| 221 |
+
font-family: 'Trebuchet MS', sans-serif;
|
| 222 |
+
transition: color 0.5s ease;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
/* Toggle Switch Styles */
|
| 226 |
+
.toggle-switch {
|
| 227 |
+
position: relative;
|
| 228 |
+
width: 50px;
|
| 229 |
+
height: 25px;
|
| 230 |
+
margin-right: 20px;
|
| 231 |
+
--light: #d8dbe0;
|
| 232 |
+
--dark: #28292c;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
.switch-label {
|
| 236 |
+
position: absolute;
|
| 237 |
+
width: 100%;
|
| 238 |
+
height: 100%;
|
| 239 |
+
background-color: var(--dark);
|
| 240 |
+
border-radius: 12px;
|
| 241 |
+
cursor: pointer;
|
| 242 |
+
border: 1.5px solid var(--dark);
|
| 243 |
+
transition: background-color 0.3s;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
.checkbox {
|
| 247 |
+
position: absolute;
|
| 248 |
+
display: none;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
.slider {
|
| 252 |
+
position: absolute;
|
| 253 |
+
width: 100%;
|
| 254 |
+
height: 100%;
|
| 255 |
+
border-radius: 12px;
|
| 256 |
+
transition: 0.3s;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
.checkbox:checked ~ .slider {
|
| 260 |
+
background-color: var(--light);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
.slider::before {
|
| 264 |
+
content: "";
|
| 265 |
+
position: absolute;
|
| 266 |
+
top: 5.5px;
|
| 267 |
+
left: 5.5px;
|
| 268 |
+
width: 14px;
|
| 269 |
+
height: 14px;
|
| 270 |
+
border-radius: 50%;
|
| 271 |
+
box-shadow: inset 7px -2px 0px 0px var(--light);
|
| 272 |
+
background-color: var(--dark);
|
| 273 |
+
transition: 0.3s;
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
.checkbox:checked ~ .slider::before {
|
| 277 |
+
transform: translateX(26px);
|
| 278 |
+
background-color: var(--dark);
|
| 279 |
+
box-shadow: none;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
.chat-box {
|
| 283 |
+
display: flex;
|
| 284 |
+
flex-direction: column;
|
| 285 |
+
flex: 1;
|
| 286 |
+
overflow-y: auto;
|
| 287 |
+
padding: 15px;
|
| 288 |
+
overflow-x: hidden;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
.chat-box::-webkit-scrollbar {
|
| 292 |
+
width: 3px;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
.chat-box::-webkit-scrollbar-track {
|
| 296 |
+
background: transparent;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
.chat-box::-webkit-scrollbar-track-piece {
|
| 300 |
+
background: #b0b0b000;
|
| 301 |
+
border-radius: 999px;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
.chat-box::-webkit-scrollbar-thumb {
|
| 305 |
+
background-color: #ffd700;
|
| 306 |
+
border-radius: 999px;
|
| 307 |
+
border: 2px solid transparent;
|
| 308 |
+
background-clip: padding-box;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
.chat-box {
|
| 312 |
+
scrollbar-width: thin;
|
| 313 |
+
scrollbar-color: #ffd700 #b0b0b0;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
.chat-box p {
|
| 317 |
+
margin: 10px 0;
|
| 318 |
+
font-size: 16px;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
.messageBox {
|
| 322 |
+
position: fixed;
|
| 323 |
+
bottom: 20px;
|
| 324 |
+
left: 50%;
|
| 325 |
+
transform: translateX(-50%);
|
| 326 |
+
display: flex;
|
| 327 |
+
align-items: center;
|
| 328 |
+
background-color: #2d2d2d;
|
| 329 |
+
padding: 0 12px;
|
| 330 |
+
border-radius: 10px;
|
| 331 |
+
border: 1px solid rgb(63, 63, 63);
|
| 332 |
+
width: 60%;
|
| 333 |
+
max-width: 800px;
|
| 334 |
+
height: 50px;
|
| 335 |
+
transition: all 0.3s ease-in-out, background-color 0.5s ease, border-color 0.5s ease;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
.light-mode .messageBox {
|
| 339 |
+
background-color: white;
|
| 340 |
+
border: 1px solid #1d495f;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
.messageBox:focus-within {
|
| 344 |
+
border: 1px solid rgb(110, 110, 110);
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
#messageInput {
|
| 348 |
+
flex: 1;
|
| 349 |
+
height: 100%;
|
| 350 |
+
background-color: transparent;
|
| 351 |
+
outline: none;
|
| 352 |
+
border: none;
|
| 353 |
+
padding: 0 12px;
|
| 354 |
+
color: white;
|
| 355 |
+
width: auto;
|
| 356 |
+
font-family: 'Roboto', sans-serif;
|
| 357 |
+
font-size: 14px;
|
| 358 |
+
transition: color 0.5s ease;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
.light-mode #messageInput {
|
| 362 |
+
color: #171717;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
#sendButton {
|
| 366 |
+
width: 50px;
|
| 367 |
+
height: 100%;
|
| 368 |
+
background-color: transparent;
|
| 369 |
+
outline: none;
|
| 370 |
+
border: none;
|
| 371 |
+
display: flex;
|
| 372 |
+
align-items: center;
|
| 373 |
+
justify-content: center;
|
| 374 |
+
cursor: pointer;
|
| 375 |
+
padding: 0;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
#sendButton svg {
|
| 379 |
+
height: 60%;
|
| 380 |
+
width: auto;
|
| 381 |
+
transition: all 0.3s;
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
#sendButton svg path {
|
| 385 |
+
transition: all 0.3s;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
#sendButton:hover svg path,
|
| 389 |
+
#sendButton:active svg path {
|
| 390 |
+
fill: #3c3c3c;
|
| 391 |
+
stroke: white;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
.light-mode #sendButton:hover svg path,
|
| 395 |
+
.light-mode #sendButton:active svg path {
|
| 396 |
+
fill: #fbf7e7;
|
| 397 |
+
stroke: #ffcd07 !important;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
.message-row {
|
| 401 |
+
width: 100%;
|
| 402 |
+
margin: 8px 0;
|
| 403 |
+
display: flex;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
.message-user {
|
| 407 |
+
justify-content: flex-end;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
.message-bot {
|
| 411 |
+
justify-content: flex-start;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
.message-bubble {
|
| 415 |
+
max-width: 70%;
|
| 416 |
+
padding: 10px 14px;
|
| 417 |
+
border-radius: 16px;
|
| 418 |
+
font-size: 14px;
|
| 419 |
+
transition: background-color 0.5s ease, color 0.5s ease;
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
.message-user .message-bubble {
|
| 423 |
+
background-color: #1b798e;
|
| 424 |
+
color: white;
|
| 425 |
+
border-bottom-right-radius: 4px;
|
| 426 |
+
border: 1px solid #FFD700;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
.message-bot .message-bubble {
|
| 430 |
+
background-color: #2d2d2d;
|
| 431 |
+
color: white;
|
| 432 |
+
border-bottom-left-radius: 4px;
|
| 433 |
+
border: 1px solid #FFD700;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
.light-mode .message-bot .message-bubble {
|
| 437 |
+
background-color: #fefdf6;
|
| 438 |
+
color: #111;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
/* NEW: Image Gallery Styles */
|
| 442 |
+
.image-gallery {
|
| 443 |
+
display: flex;
|
| 444 |
+
gap: 10px;
|
| 445 |
+
flex-wrap: wrap;
|
| 446 |
+
margin-top: 12px;
|
| 447 |
+
padding-top: 12px;
|
| 448 |
+
border-top: 1px solid rgba(255, 215, 0, 0.3);
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
.image-container {
|
| 452 |
+
position: relative;
|
| 453 |
+
border-radius: 8px;
|
| 454 |
+
overflow: hidden;
|
| 455 |
+
cursor: pointer;
|
| 456 |
+
transition: transform 0.2s ease;
|
| 457 |
+
border: 2px solid #FFD700;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
.image-container:hover {
|
| 461 |
+
transform: scale(1.05);
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
.image-container img {
|
| 465 |
+
width: 150px;
|
| 466 |
+
height: 150px;
|
| 467 |
+
object-fit: cover;
|
| 468 |
+
display: block;
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
.image-caption {
|
| 472 |
+
position: absolute;
|
| 473 |
+
bottom: 0;
|
| 474 |
+
left: 0;
|
| 475 |
+
right: 0;
|
| 476 |
+
background: rgba(0, 0, 0, 0.7);
|
| 477 |
+
color: #FFD700;
|
| 478 |
+
padding: 4px 8px;
|
| 479 |
+
font-size: 10px;
|
| 480 |
+
text-align: center;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
/* NEW: Image Modal/Lightbox */
|
| 484 |
+
.image-modal {
|
| 485 |
+
display: none;
|
| 486 |
+
position: fixed;
|
| 487 |
+
z-index: 9999;
|
| 488 |
+
left: 0;
|
| 489 |
+
top: 0;
|
| 490 |
+
width: 100%;
|
| 491 |
+
height: 100%;
|
| 492 |
+
background-color: rgba(0, 0, 0, 0.9);
|
| 493 |
+
align-items: center;
|
| 494 |
+
justify-content: center;
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
.image-modal.active {
|
| 498 |
+
display: flex;
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
.modal-content {
|
| 502 |
+
max-width: 90%;
|
| 503 |
+
max-height: 90%;
|
| 504 |
+
border-radius: 8px;
|
| 505 |
+
box-shadow: 0 4px 20px rgba(255, 215, 0, 0.5);
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
.modal-close {
|
| 509 |
+
position: absolute;
|
| 510 |
+
top: 20px;
|
| 511 |
+
right: 35px;
|
| 512 |
+
color: #FFD700;
|
| 513 |
+
font-size: 40px;
|
| 514 |
+
font-weight: bold;
|
| 515 |
+
cursor: pointer;
|
| 516 |
+
transition: color 0.3s;
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
.modal-close:hover {
|
| 520 |
+
color: #fff;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
.fileUploadWrapper {
|
| 524 |
+
position: relative;
|
| 525 |
+
display: flex;
|
| 526 |
+
align-items: center;
|
| 527 |
+
margin-right: 8px;
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
.fileUploadWrapper label {
|
| 531 |
+
display: flex;
|
| 532 |
+
align-items: center;
|
| 533 |
+
cursor: pointer;
|
| 534 |
+
padding: 0;
|
| 535 |
+
margin: 0;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
.fileUploadWrapper svg {
|
| 539 |
+
width: 20px;
|
| 540 |
+
height: 20px;
|
| 541 |
+
fill: #6c6c6c;
|
| 542 |
+
transition: all 0.3s ease;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
.fileUploadWrapper label:hover svg {
|
| 546 |
+
fill: #10a37f;
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
.fileUploadWrapper .tooltip {
|
| 550 |
+
display: none;
|
| 551 |
+
position: absolute;
|
| 552 |
+
bottom: 125%;
|
| 553 |
+
left: 50%;
|
| 554 |
+
transform: translateX(-50%);
|
| 555 |
+
background-color: #000;
|
| 556 |
+
color: #fff;
|
| 557 |
+
padding: 5px 10px;
|
| 558 |
+
border-radius: 6px;
|
| 559 |
+
font-size: 12px;
|
| 560 |
+
white-space: nowrap;
|
| 561 |
+
z-index: 1;
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
.fileUploadWrapper label:hover .tooltip {
|
| 565 |
+
display: block;
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
.fileUploadWrapper input {
|
| 569 |
+
display: none;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
.fileUploadWrapper,
|
| 573 |
+
#sendButton {
|
| 574 |
+
width: 40px;
|
| 575 |
+
height: 40px;
|
| 576 |
+
display: flex;
|
| 577 |
+
align-items: center;
|
| 578 |
+
justify-content: center;
|
| 579 |
+
margin: 0 8px;
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
.fileUploadWrapper svg,
|
| 583 |
+
#sendButton svg {
|
| 584 |
+
width: 20px;
|
| 585 |
+
height: 20px;
|
| 586 |
+
fill: #6c6c6c;
|
| 587 |
+
transition: all 0.3s ease;
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
.fileUploadWrapper label:hover svg path,
|
| 591 |
+
.fileUploadWrapper label:hover svg circle,
|
| 592 |
+
#sendButton:hover svg path {
|
| 593 |
+
stroke: white;
|
| 594 |
+
transition: stroke 0.3s ease;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
.light-mode .fileUploadWrapper label:hover svg path,
|
| 598 |
+
.light-mode .fileUploadWrapper label:hover svg circle,
|
| 599 |
+
.light-mode #sendButton:hover svg path {
|
| 600 |
+
stroke: #ffcd07;
|
| 601 |
+
transition: stroke 0.3s ease;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
#sendButton svg,
|
| 605 |
+
.fileUploadWrapper svg {
|
| 606 |
+
transition: all 0.3s ease;
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
/* Mode Dropdown Styles */
|
| 610 |
+
.mode-dropdown-wrapper {
|
| 611 |
+
position: relative;
|
| 612 |
+
display: flex;
|
| 613 |
+
align-items: center;
|
| 614 |
+
margin: 0 8px;
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
.mode-dropdown-button {
|
| 618 |
+
display: inline-flex;
|
| 619 |
+
justify-content: center;
|
| 620 |
+
align-items: center;
|
| 621 |
+
padding: 6px 12px;
|
| 622 |
+
background-color: transparent;
|
| 623 |
+
border: none;
|
| 624 |
+
color: #e5e5e5;
|
| 625 |
+
font-size: 14px;
|
| 626 |
+
font-family: 'Roboto', sans-serif;
|
| 627 |
+
font-weight: 500;
|
| 628 |
+
cursor: pointer;
|
| 629 |
+
border-radius: 6px;
|
| 630 |
+
white-space: nowrap;
|
| 631 |
+
transition: background-color 0.3s, color 0.5s ease;
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
.light-mode .mode-dropdown-button {
|
| 635 |
+
background-color: #f9f9f9;
|
| 636 |
+
color: #171717;
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
.mode-dropdown-button:hover {
|
| 640 |
+
background-color: rgba(110, 110, 110, 0.12);
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
.light-mode .mode-dropdown-button:hover {
|
| 644 |
+
background-color: #f0f0f0;
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
.dropdown-arrow {
|
| 648 |
+
width: 16px;
|
| 649 |
+
height: 16px;
|
| 650 |
+
margin-left: 6px;
|
| 651 |
+
transition: transform 0.2s ease, color 0.5s ease;
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
.mode-dropdown-button[aria-expanded="true"] .dropdown-arrow {
|
| 655 |
+
transform: rotate(180deg);
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
.mode-dropdown-menu {
|
| 659 |
+
position: absolute;
|
| 660 |
+
bottom: 100%;
|
| 661 |
+
right: 0;
|
| 662 |
+
margin-bottom: 8px;
|
| 663 |
+
min-width: 180px;
|
| 664 |
+
background-color: #2d2d2d;
|
| 665 |
+
border: 1px solid rgb(63, 63, 63);
|
| 666 |
+
border-radius: 8px;
|
| 667 |
+
box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.4);
|
| 668 |
+
z-index: 1000;
|
| 669 |
+
opacity: 0;
|
| 670 |
+
transform: translateY(-8px);
|
| 671 |
+
transition: opacity 0.2s ease, transform 0.2s ease, background-color 0.5s ease, border-color 0.5s ease;
|
| 672 |
+
pointer-events: none;
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
.mode-dropdown-menu:not(.hidden) {
|
| 676 |
+
opacity: 1;
|
| 677 |
+
transform: translateY(0);
|
| 678 |
+
pointer-events: auto;
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
.light-mode .mode-dropdown-menu {
|
| 682 |
+
background-color: #ffffff;
|
| 683 |
+
border-color: #1d495f;
|
| 684 |
+
box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.15);
|
| 685 |
+
}
|
| 686 |
+
|
| 687 |
+
.dropdown-items {
|
| 688 |
+
padding: 4px 0;
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
.dropdown-item {
|
| 692 |
+
display: flex;
|
| 693 |
+
align-items: center;
|
| 694 |
+
padding: 10px 16px;
|
| 695 |
+
color: #ddd;
|
| 696 |
+
text-decoration: none;
|
| 697 |
+
font-size: 14px;
|
| 698 |
+
font-family: 'Roboto', sans-serif;
|
| 699 |
+
cursor: pointer;
|
| 700 |
+
transition: background-color 0.15s ease, color 0.5s ease;
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
.dropdown-item:first-child {
|
| 704 |
+
border-top-left-radius: 8px;
|
| 705 |
+
border-top-right-radius: 8px;
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
.dropdown-item:last-child {
|
| 709 |
+
border-bottom-left-radius: 8px;
|
| 710 |
+
border-bottom-right-radius: 8px;
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
.dropdown-item:hover {
|
| 714 |
+
background-color: rgba(110, 110, 110, 0.25);
|
| 715 |
+
color: #fff;
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
.light-mode .dropdown-item {
|
| 719 |
+
color: #1f2937;
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
.light-mode .dropdown-item:hover {
|
| 723 |
+
background-color: #f3f4f6;
|
| 724 |
+
color: #1e40af;
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
.dropdown-item .font-semibold {
|
| 728 |
+
font-weight: 600;
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
.hidden {
|
| 732 |
+
display: none;
|
| 733 |
+
}
|
| 734 |
+
</style>
|
| 735 |
+
</head>
|
| 736 |
+
<body>
|
| 737 |
+
<div class="sidebar collapsed">
|
| 738 |
+
<div class="tooltip">
|
| 739 |
+
<span class="tooltiptext">Open Sidebar</span>
|
| 740 |
+
<button id="sidebar-toggle">
|
| 741 |
+
<i class="fas fa-chevron-right"></i>
|
| 742 |
+
</button>
|
| 743 |
+
</div>
|
| 744 |
+
<div class="sidebar-content">
|
| 745 |
+
<div class="conversation-list">
|
| 746 |
+
<div class="conversation">
|
| 747 |
+
<p class="conversation-text">Last Conversation:</p>
|
| 748 |
+
<p class="conversation-content">No conversation yet</p>
|
| 749 |
+
</div>
|
| 750 |
+
</div>
|
| 751 |
+
<button id="new-conversation-btn">Start New Conversation</button>
|
| 752 |
+
</div>
|
| 753 |
+
</div>
|
| 754 |
+
|
| 755 |
+
<div class="chat-container light-mode">
|
| 756 |
+
<div class="chat-content">
|
| 757 |
+
<div class="chat-header">
|
| 758 |
+
<div class="logo-container">
|
| 759 |
+
<h1>BeRU </h1>
|
| 760 |
+
</div>
|
| 761 |
+
<div class="toggle-switch">
|
| 762 |
+
<label class="switch-label">
|
| 763 |
+
<input type="checkbox" id="toggle-checkbox" class="checkbox">
|
| 764 |
+
<span class="slider"></span>
|
| 765 |
+
</label>
|
| 766 |
+
</div>
|
| 767 |
+
</div>
|
| 768 |
+
|
| 769 |
+
<div id="chat-box" class="chat-box"></div>
|
| 770 |
+
|
| 771 |
+
<div class="messageBox">
|
| 772 |
+
<div class="fileUploadWrapper">
|
| 773 |
+
<label for="file">
|
| 774 |
+
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 337 337">
|
| 775 |
+
<circle stroke-width="20" stroke="#6c6c6c" fill="none" r="158.5" cy="168.5" cx="168.5"></circle>
|
| 776 |
+
<path stroke-linecap="round" stroke-width="25" stroke="#6c6c6c" d="M167.759 79V259"></path>
|
| 777 |
+
<path stroke-linecap="round" stroke-width="25" stroke="#6c6c6c" d="M79 167.138H259"></path>
|
| 778 |
+
</svg>
|
| 779 |
+
<span class="tooltip">Add an image</span>
|
| 780 |
+
</label>
|
| 781 |
+
<input type="file" id="file" name="file" />
|
| 782 |
+
</div>
|
| 783 |
+
|
| 784 |
+
<input required="" placeholder="Message..." type="text" id="messageInput" />
|
| 785 |
+
|
| 786 |
+
<div class="mode-dropdown-wrapper">
|
| 787 |
+
<button id="modeDropdownButton" type="button" class="mode-dropdown-button" aria-expanded="false" aria-haspopup="true">
|
| 788 |
+
Detailed
|
| 789 |
+
<svg class="dropdown-arrow" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
| 790 |
+
<path fill-rule="evenodd" d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z" clip-rule="evenodd" />
|
| 791 |
+
</svg>
|
| 792 |
+
</button>
|
| 793 |
+
<div id="modeDropdownMenu" class="mode-dropdown-menu hidden" role="menu" aria-orientation="vertical" aria-labelledby="modeDropdownButton" tabindex="-1">
|
| 794 |
+
<div class="dropdown-items">
|
| 795 |
+
<a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
|
| 796 |
+
<span class="font-semibold">Short and Concise</span>
|
| 797 |
+
</a>
|
| 798 |
+
<a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
|
| 799 |
+
<span class="font-semibold">Detailed</span>
|
| 800 |
+
</a>
|
| 801 |
+
<a href="#" class="dropdown-item" role="menuitem" tabindex="-1">
|
| 802 |
+
<span class="font-semibold">Step-by-Step</span>
|
| 803 |
+
</a>
|
| 804 |
+
</div>
|
| 805 |
+
</div>
|
| 806 |
+
</div>
|
| 807 |
+
|
| 808 |
+
<button id="sendButton">
|
| 809 |
+
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 664 663">
|
| 810 |
+
<path fill="none" d="M646.293 331.888L17.7538 17.6187L155.245 331.888M646.293 331.888L17.753 646.157L155.245 331.888M646.293 331.888L318.735 330.228L155.245 331.888"></path>
|
| 811 |
+
<path stroke-linejoin="round" stroke-linecap="round" stroke-width="33.67" stroke="#6c6c6c" d="M646.293 331.888L17.7538 17.6187L155.245 331.888M646.293 331.888L17.753 646.157L155.245 331.888M646.293 331.888L318.735 330.228L155.245 331.888"></path>
|
| 812 |
+
</svg>
|
| 813 |
+
</button>
|
| 814 |
+
</div>
|
| 815 |
+
</div>
|
| 816 |
+
</div>
|
| 817 |
+
|
| 818 |
+
<!-- NEW: Image Modal -->
|
| 819 |
+
<div id="imageModal" class="image-modal">
|
| 820 |
+
<span class="modal-close" id="modalClose">×</span>
|
| 821 |
+
<img class="modal-content" id="modalImage">
|
| 822 |
+
</div>
|
| 823 |
+
|
| 824 |
+
<script>
|
| 825 |
+
const chatBox = document.getElementById('chat-box');
|
| 826 |
+
const userInput = document.getElementById('messageInput');
|
| 827 |
+
const sendButton = document.getElementById('sendButton');
|
| 828 |
+
const sidebarToggle = document.getElementById('sidebar-toggle');
|
| 829 |
+
const modeToggle = document.getElementById('toggle-checkbox');
|
| 830 |
+
const sidebar = document.querySelector('.sidebar');
|
| 831 |
+
const chatContainer = document.querySelector('.chat-container');
|
| 832 |
+
const messageBox = document.querySelector('.messageBox');
|
| 833 |
+
const imageModal = document.getElementById('imageModal');
|
| 834 |
+
const modalImage = document.getElementById('modalImage');
|
| 835 |
+
const modalClose = document.getElementById('modalClose');
|
| 836 |
+
|
| 837 |
+
let currentMode = 'Detailed';
|
| 838 |
+
let sessionId = 'session-' + Date.now();
|
| 839 |
+
|
| 840 |
+
// β
NEW: Image Modal Functions
|
| 841 |
+
function openImageModal(imageSrc) {
|
| 842 |
+
imageModal.classList.add('active');
|
| 843 |
+
modalImage.src = imageSrc;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
// Hide the lightbox and clear the displayed image.
function closeImageModal() {
    modalImage.setAttribute('src', '');
    imageModal.classList.remove('active');
}
|
| 850 |
+
|
| 851 |
+
modalClose.addEventListener('click', closeImageModal);
|
| 852 |
+
imageModal.addEventListener('click', (e) => {
|
| 853 |
+
if (e.target === imageModal) {
|
| 854 |
+
closeImageModal();
|
| 855 |
+
}
|
| 856 |
+
});
|
| 857 |
+
|
| 858 |
+
// Utility Functions
|
| 859 |
+
// Current wall-clock time, formatted for the user's locale.
function getCurrentTime() {
    return new Date().toLocaleTimeString();
}
|
| 863 |
+
|
| 864 |
+
// MODIFIED: Chat Functions with Image Support
|
| 865 |
+
// Render one chat message (plus optional retrieved images) into the chat box.
//
// sender:  'user' renders right-aligned; any other value renders as the bot.
// message: plain text; newlines are shown as <br>. The text is HTML-escaped
//          first — the previous version injected it raw via innerHTML, which
//          allowed HTML/script injection (XSS) from user input or responses.
// images:  optional array of {data, source, page} objects (data is an image
//          URL or data: URI), rendered as a clickable thumbnail gallery.
function appendMessage(sender, message, images = []) {
    const wrapper = document.createElement('div');
    wrapper.classList.add('message-row');
    wrapper.classList.add(sender === 'user' ? 'message-user' : 'message-bot');

    const bubble = document.createElement('div');
    bubble.classList.add('message-bubble');

    // Escape HTML special characters, then restore intentional line breaks.
    const escaped = String(message)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;');
    bubble.innerHTML = escaped.replace(/\n/g, '<br>');

    // Attach a thumbnail gallery when the response carries source images.
    if (images && images.length > 0) {
        const gallery = document.createElement('div');
        gallery.classList.add('image-gallery');

        images.forEach((img) => {
            const imgContainer = document.createElement('div');
            imgContainer.classList.add('image-container');

            const imgElement = document.createElement('img');
            imgElement.src = img.data;
            imgElement.alt = `Image from ${img.source}`;
            imgElement.loading = 'lazy';

            // Clicking a thumbnail opens the fullscreen lightbox.
            imgElement.addEventListener('click', () => {
                openImageModal(img.data);
            });

            const caption = document.createElement('div');
            caption.classList.add('image-caption');
            caption.textContent = `π ${img.source} | Page ${img.page}`;

            imgContainer.appendChild(imgElement);
            imgContainer.appendChild(caption);
            gallery.appendChild(imgContainer);
        });

        bubble.appendChild(gallery);
    }

    wrapper.appendChild(bubble);
    chatBox.appendChild(wrapper);
    // Keep the newest message scrolled into view.
    chatBox.scrollTop = chatBox.scrollHeight;
}
|
| 911 |
+
|
| 912 |
+
// Send the current input box contents to the backend and render the reply.
// Shows a temporary "Thinking..." bubble while waiting; on any failure the
// bubble is replaced with an error message instead of leaving the UI hanging.
async function sendMessage() {
    const message = userInput.value.trim();
    if (message === '') return;

    // Echo the user's message and clear the input.
    appendMessage('user', message);
    userInput.value = '';

    // Placeholder bubble; removed once the server responds (or fails).
    appendMessage('ChatGPT', 'β³ Thinking...');

    try {
        const response = await fetch('/api/chat', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                message: message,
                mode: currentMode,
                session_id: sessionId,
                include_images: true // request retrieved page images too
            })
        });

        // Fail fast on HTTP errors; otherwise response.json() may choke on
        // an HTML error page, or data.response may be undefined.
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }

        const data = await response.json();

        // Remove the "Thinking..." placeholder.
        chatBox.removeChild(chatBox.lastChild);

        if (data.error) {
            appendMessage('ChatGPT', 'β Error: ' + data.error);
        } else {
            const images = data.images || [];
            appendMessage('ChatGPT', data.response, images);

            if (images.length > 0) {
                console.log(`π· Received ${images.length} images`);
            }
        }
    } catch (error) {
        // Remove the "Thinking..." placeholder before showing the error.
        chatBox.removeChild(chatBox.lastChild);
        appendMessage('ChatGPT', 'β Connection error. Please check if the server is running.');
        console.error('Error:', error);
    }
}
|
| 961 |
+
|
| 962 |
+
// Event Listeners — all UI wiring runs after the DOM is ready.
// NOTE(review): sidebar, sidebarToggle, messageBox, chatBox, chatContainer,
// modeToggle, sendButton, userInput, sessionId and currentMode are globals
// defined earlier in this script — confirm they exist before this handler.
document.addEventListener('DOMContentLoaded', function() {
    const newConversationBtn = document.getElementById('new-conversation-btn');
    const conversationContent = document.querySelector('.conversation-content');

    // Sidebar Toggle: collapse/expand, then re-centre the input box.
    sidebarToggle.addEventListener('click', function() {
        sidebar.classList.toggle('collapsed');
        adjustMessageBoxPosition();
    });

    // Keep the message input horizontally centred in the chat area,
    // accounting for the sidebar's current width (50px collapsed, 300px open —
    // these must match the CSS widths).
    function adjustMessageBoxPosition() {
        const sidebarWidth = sidebar.classList.contains('collapsed') ? 50 : 300;
        const chatAreaWidth = window.innerWidth - sidebarWidth;
        messageBox.style.left = sidebarWidth + (chatAreaWidth / 2) + 'px';
        messageBox.style.transform = 'translateX(-50%)';
    }

    adjustMessageBoxPosition();
    window.addEventListener('resize', adjustMessageBoxPosition);

    // New Conversation: clear the chat UI, mint a fresh session id, and tell
    // the backend so any server-side history is reset too.
    newConversationBtn.addEventListener('click', async function() {
        conversationContent.textContent = 'New Conversation Started!';
        chatBox.innerHTML = '';
        sessionId = 'session-' + Date.now();

        try {
            await fetch('/api/new-conversation', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({session_id: sessionId})
            });
        } catch (error) {
            // Best-effort: the UI is already reset even if the backend call fails.
            console.error('Error starting new conversation:', error);
        }

        adjustMessageBoxPosition();
    });

    // Theme Toggle: swap light/dark classes on both body and chat container.
    modeToggle.addEventListener('change', function() {
        document.body.classList.toggle('dark-mode');
        document.body.classList.toggle('light-mode');
        chatContainer.classList.toggle('light-mode');
        chatContainer.classList.toggle('dark-mode');
        adjustMessageBoxPosition();
    });

    // Set initial mode
    document.body.classList.add('light-mode');

    // Send on button click or on the Enter key.
    sendButton.addEventListener('click', sendMessage);
    userInput.addEventListener('keydown', (event) => {
        if (event.key === 'Enter') sendMessage();
    });

    // Mode Dropdown (retrieval-mode selector).
    const modeDropdownButton = document.getElementById('modeDropdownButton');
    const modeDropdownMenu = document.getElementById('modeDropdownMenu');

    if (modeDropdownButton && modeDropdownMenu) {
        // Show/hide the menu and keep aria-expanded in sync for accessibility.
        function toggleMenu() {
            const isHidden = modeDropdownMenu.classList.contains('hidden');
            if (isHidden) {
                modeDropdownMenu.classList.remove('hidden');
                modeDropdownButton.setAttribute('aria-expanded', 'true');
            } else {
                modeDropdownMenu.classList.add('hidden');
                modeDropdownButton.setAttribute('aria-expanded', 'false');
            }
        }

        modeDropdownButton.addEventListener('click', (event) => {
            // Stop the document-level click handler from immediately re-closing.
            event.stopPropagation();
            toggleMenu();
        });

        // Close the menu on any click outside the button or menu.
        document.addEventListener('click', (event) => {
            if (!modeDropdownButton.contains(event.target) && !modeDropdownMenu.contains(event.target)) {
                if (!modeDropdownMenu.classList.contains('hidden')) {
                    toggleMenu();
                }
            }
        });

        // Selecting an item updates the global currentMode (read by
        // sendMessage) and rewrites the button's label text node.
        modeDropdownMenu.querySelectorAll('.dropdown-item').forEach(item => {
            item.addEventListener('click', (event) => {
                event.preventDefault();
                const selectedMode = event.currentTarget.querySelector('.font-semibold').textContent.trim();
                console.log('Mode selected:', selectedMode);

                currentMode = selectedMode;

                // The button contains icon children plus one bare text node;
                // find that text node rather than clobbering innerHTML.
                const buttonTextNode = Array.from(modeDropdownButton.childNodes).find(node =>
                    node.nodeType === Node.TEXT_NODE && node.textContent.trim() !== ''
                );

                if (buttonTextNode) {
                    buttonTextNode.textContent = selectedMode;
                } else {
                    const textNode = document.createTextNode(selectedMode);
                    modeDropdownButton.insertBefore(textNode, modeDropdownButton.querySelector('.dropdown-arrow'));
                }

                toggleMenu();
            });
        });
    }
});
|
| 1073 |
+
</script>
|
| 1074 |
+
</body>
|
| 1075 |
+
</html>
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.28.0
|
| 2 |
+
torch==2.0.0
|
| 3 |
+
transformers==4.36.0
|
| 4 |
+
langchain==0.1.0
|
| 5 |
+
langchain-community==0.0.10
|
| 6 |
+
langchain-core==0.1.8
|
| 7 |
+
faiss-cpu==1.7.4
|
| 8 |
+
pydantic==2.5.0
|
| 9 |
+
numpy==1.24.3
|
| 10 |
+
dill==0.3.7
|
| 11 |
+
bitsandbytes==0.41.1
|
| 12 |
+
flashrank==0.2.0
|
| 13 |
+
PyMuPDF==1.23.8
|
| 14 |
+
Pillow==10.0.1
|
| 15 |
+
pytesseract==0.3.10
|
| 16 |
+
pdf2image==1.16.3
|
| 17 |
+
rank-bm25==0.2.2
|
| 18 |
+
huggingface-hub==0.18.0
|
| 19 |
+
peft==0.4.0
|
spaces_app.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BeRU RAG Chat App - Optimized for Hugging Face Spaces
|
| 3 |
+
Deployment: https://huggingface.co/spaces/AnwinMJ/Beru
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import streamlit as st
|
| 7 |
+
import torch
|
| 8 |
+
import os
|
| 9 |
+
import pickle
|
| 10 |
+
import faiss
|
| 11 |
+
import numpy as np
|
| 12 |
+
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
| 13 |
+
from typing import List, Dict
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
# Setup logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# ========================================
|
| 22 |
+
# π¨ STREAMLIT PAGE CONFIG
|
| 23 |
+
# ========================================
|
| 24 |
+
st.set_page_config(
|
| 25 |
+
page_title="BeRU Chat - RAG Assistant",
|
| 26 |
+
page_icon="π€",
|
| 27 |
+
layout="wide",
|
| 28 |
+
initial_sidebar_state="expanded",
|
| 29 |
+
menu_items={
|
| 30 |
+
"About": "BeRU - Offline RAG System with VLM2Vec and Mistral 7B"
|
| 31 |
+
}
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# ========================================
|
| 35 |
+
# π ENVIRONMENT DETECTION
|
| 36 |
+
# ========================================
|
| 37 |
+
def detect_environment():
    """Detect the runtime environment (HF Spaces vs local) and hardware.

    Returns:
        dict with keys:
            is_spaces (bool): True when running on Hugging Face Spaces.
            device (str): 'cuda' if a GPU is available, else 'cpu'.
            model_cache (str): model download cache dir (HF_HOME or ./cache).
            gpu_memory (int): total VRAM in bytes, 0 when no GPU.
    """
    # Bug fix: the original evaluated `'huggingface' in os.path.exists('/app')`,
    # which raises TypeError (membership test on a bool) whenever the SPACES
    # env var is not 'true'. The intended check is simply whether /app exists.
    is_spaces = os.getenv('SPACES', 'false').lower() == 'true' or os.path.exists('/app')
    return {
        'is_spaces': is_spaces,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'model_cache': os.getenv('HF_HOME', './cache'),
        'gpu_memory': torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else 0
    }
|
| 46 |
+
|
| 47 |
+
# Module-level environment probe; env_info is read by the model loaders below.
env_info = detect_environment()

# Display environment info in sidebar (re-runs on every Streamlit rerun; cheap).
with st.sidebar:
    st.write("### System Info")
    st.write(f"🖥️ Device: `{env_info['device'].upper()}`")
    if env_info['device'] == 'cuda':
        # total_memory is in bytes; show as GB.
        st.write(f"💾 GPU VRAM: `{env_info['gpu_memory'] / 1e9:.1f} GB`")
    st.write(f"📦 Cache: `{env_info['model_cache']}`")
|
| 57 |
+
|
| 58 |
+
# ========================================
|
| 59 |
+
# π― MODEL LOADING WITH CACHING
|
| 60 |
+
# ========================================
|
| 61 |
+
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model, processor and tokenizer.

    Cached by st.cache_resource, so the download/load cost is paid only on
    the first Streamlit run.

    Returns:
        (model, processor, tokenizer, device) where device is "cuda" or "cpu".

    Raises:
        Re-raises any loading exception after surfacing it in the UI and logs.
    """
    with st.spinner("⏳ Loading embedding model... (first time may take 5 min)"):
        try:
            logger.info("Loading VLM2Vec model...")
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # fp16 on GPU halves memory; fp32 on CPU where fp16 inference is
            # poorly supported.
            model = AutoModel.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                cache_dir=env_info['model_cache']
            ).to(device)

            processor = AutoProcessor.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                cache_dir=env_info['model_cache']
            )

            tokenizer = AutoTokenizer.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                cache_dir=env_info['model_cache']
            )

            # Inference only — disable dropout etc.
            model.eval()
            logger.info("✅ Embedding model loaded successfully")
            st.success("✅ Embedding model loaded!")

            return model, processor, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading embedding model: {str(e)}")
            logger.error(f"Model loading error: {e}")
            raise
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@st.cache_resource
def load_llm_model():
    """Load Mistral-7B-Instruct with 4-bit quantization.

    Cached by st.cache_resource, so the load happens once per process.

    Returns:
        (model, tokenizer, device) — note device_map="auto" may shard the
        model; `device` reflects only CUDA availability.

    Raises:
        Re-raises any loading exception after surfacing it in the UI and logs.
    """
    with st.spinner("⏳ Loading LLM model... (first time may take 5 min)"):
        try:
            logger.info("Loading Mistral-7B model...")
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Imported lazily so the heavy bitsandbytes dependency is only
            # touched when the LLM is actually loaded.
            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

            # 4-bit NF4 quantization with double quantization: ~4x memory
            # reduction; compute still happens in bfloat16.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            tokenizer = AutoTokenizer.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.3",
                cache_dir=env_info['model_cache']
            )

            model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.3",
                quantization_config=bnb_config,
                device_map="auto",  # let accelerate place layers across devices
                cache_dir=env_info['model_cache']
            )

            logger.info("✅ LLM model loaded successfully")
            st.success("✅ LLM model loaded!")

            return model, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading LLM: {str(e)}")
            logger.error(f"LLM loading error: {e}")
            raise
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ========================================
|
| 142 |
+
# π UI LAYOUT
|
| 143 |
+
# ========================================
|
| 144 |
+
# Page header.
# NOTE(review): the title emoji is mojibake ("οΏ½οΏ½") in the source — confirm
# the intended glyph and re-save the file as UTF-8.
st.title("οΏ½οΏ½ BeRU Chat - RAG Assistant")
st.markdown("""
A powerful offline RAG system combining Mistral 7B LLM with VLM2Vec embeddings
for intelligent document search and conversation.

**Status**: Models loading on first access (5-8 minutes)
""")

# Load models (cached by @st.cache_resource, so only the first run is slow).
try:
    embedding_model, processor, tokenizer, device = load_embedding_model()
    llm_model, llm_tokenizer, llm_device = load_llm_model()
    models_loaded = True
except Exception as e:
    st.error(f"Failed to load models: {str(e)}")
    models_loaded = False

if models_loaded:
    # Main chat interface: chat on the left (2/3), info/settings on the right.
    left_col, right_col = st.columns([2, 1])

    with left_col:
        st.subheader("💬 Chat")

        # Conversation history survives Streamlit reruns via session_state.
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Replay prior turns.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        # Chat input
        user_input = st.chat_input("Ask a question about your documents...")

        if user_input:
            # Record and render the user's turn.
            st.session_state.messages.append({"role": "user", "content": user_input})

            with st.chat_message("user"):
                st.write(user_input)

            # Generate response
            with st.chat_message("assistant"):
                with st.spinner("🤔 Thinking..."):
                    # NOTE(review): placeholder only — the actual RAG
                    # retrieval + generation pipeline is not wired in here.
                    response = "Response generated from RAG system..."
                    st.write(response)
                    st.session_state.messages.append({"role": "assistant", "content": response})

    with right_col:
        st.subheader("📊 Info")
        st.info("""
        **Model Info:**
        - 🧠 Embedding: VLM2Vec-Qwen2VL-2B
        - 💬 LLM: Mistral-7B-Instruct
        - 🔍 Search: FAISS + BM25

        **Performance:**
        - Device: GPU if available
        - Quantization: 4-bit
        - Context: Multi-turn
        """)

        st.subheader("⚙️ Settings")
        # NOTE(review): these sliders are rendered but their values are not
        # consumed by the (placeholder) generation path above.
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
        max_tokens = st.slider("Max Tokens", 100, 2000, 512)

else:
    st.error("❌ Failed to initialize models. Check logs for details.")
    st.info("Try refreshing the page or restarting the Space.")

# ========================================
# 🔗 FOOTER
# ========================================
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <small>
    BeRU RAG System |
    <a href='https://huggingface.co/spaces/AnwinMJ/Beru'>Space</a> |
    <a href='https://github.com/AnwinMJ/BeRU'>GitHub</a>
    </small>
</div>
""", unsafe_allow_html=True)
|
vlm2rag2.py
ADDED
|
@@ -0,0 +1,1354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import gc
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
import hashlib
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Dict, Tuple, Optional
|
| 10 |
+
import fitz # PyMuPDF
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from transformers import AutoModel, AutoProcessor, AutoTokenizer # Changed from AutoModelForCausalLM
|
| 15 |
+
from langchain_core.documents import Document
|
| 16 |
+
import pickle
|
| 17 |
+
from numpy.linalg import norm
|
| 18 |
+
import camelot
|
| 19 |
+
import base64
|
| 20 |
+
import pytesseract
|
| 21 |
+
from pdf2image import convert_from_path
|
| 22 |
+
import faiss
|
| 23 |
+
from rank_bm25 import BM25Okapi
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ========================================
|
| 27 |
+
# π CONFIGURATION
|
| 28 |
+
# ========================================
|
| 29 |
+
# Input PDFs to index.
# NOTE(review): hard-coded local Windows path — this will not exist on HF
# Spaces or any Linux deployment; should come from an env var or config.
PDF_DIR = r"D:\BeRU\testing"
FAISS_INDEX_PATH = "VLM2Vec-V2rag2"      # directory for the persisted FAISS index
MODEL_CACHE_DIR = ".cache"               # Hugging Face download cache
IMAGE_OUTPUT_DIR = "extracted_images2"   # extracted PDF page images

# Chunking configuration (all sizes in words)
CHUNK_SIZE = 450  # words — target chunk length
OVERLAP = 100  # words — shared between consecutive chunks for context continuity
MIN_CHUNK_SIZE = 50   # chunks shorter than this are dropped/merged
MAX_CHUNK_SIZE = 800  # hard upper bound on a chunk

# Instruction prefixes for better embeddings (asymmetric document/query
# prompts, as used by instruction-tuned embedding models)
DOCUMENT_INSTRUCTION = "Represent this technical document for semantic search: "
QUERY_INSTRUCTION = "Represent this question for finding relevant technical information: "

# Hybrid search weights — dense + sparse sum to 1.0 (convex combination)
DENSE_WEIGHT = 0.4  # Weight for semantic search
SPARSE_WEIGHT = 0.6  # Weight for keyword search

# Create directories up front so later writes cannot fail on a missing path
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(FAISS_INDEX_PATH, exist_ok=True)
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
|
| 53 |
+
|
| 54 |
+
# ========================================
|
| 55 |
+
# π€ VLM2Vec-V2 WRAPPER (ENHANCED)
|
| 56 |
+
# ========================================
|
| 57 |
+
class VLM2VecEmbeddings:
|
| 58 |
+
"""VLM2Vec-V2 embedding class with instruction prefixes."""
|
| 59 |
+
|
| 60 |
+
    def __init__(self, model_name: str = "TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir: str = None):
        """Load model, processor and tokenizer, and probe the embedding size.

        Args:
            model_name: Hugging Face hub id of the VLM2Vec checkpoint.
            cache_dir: optional local cache directory for downloaded weights.

        Raises:
            Re-raises any loading/probing exception after printing it.
        """
        print(f"🤖 Loading VLM2Vec-V2 model: {model_name}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" Device: {self.device}")

        try:
            # fp16 on GPU for memory; fp32 on CPU where fp16 is poorly supported.
            self.model = AutoModel.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)

            self.processor = AutoProcessor.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            # Inference only.
            self.model.eval()

            # Probe with a dummy forward pass to discover the hidden size of
            # the last layer, which is the embedding dimension used downstream.
            test_input = self.tokenizer("test", return_tensors="pt").to(self.device)
            with torch.no_grad():
                test_output = self.model(**test_input, output_hidden_states=True)
                self.embedding_dim = test_output.hidden_states[-1].shape[-1]

            print(f" Embedding dimension: {self.embedding_dim}")
            print("✅ VLM2Vec-V2 loaded successfully\n")

        except Exception as e:
            print(f"❌ Error loading VLM2Vec-V2: {e}")
            raise
|
| 100 |
+
|
| 101 |
+
def normalize_text(self, text: str) -> str:
|
| 102 |
+
"""Normalize text for better embeddings."""
|
| 103 |
+
# Remove excessive whitespace
|
| 104 |
+
text = re.sub(r'\s+', ' ', text)
|
| 105 |
+
|
| 106 |
+
# Remove page numbers
|
| 107 |
+
text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)
|
| 108 |
+
|
| 109 |
+
# Normalize unicode
|
| 110 |
+
text = text.strip()
|
| 111 |
+
|
| 112 |
+
return text
|
| 113 |
+
|
| 114 |
+
    def embed_documents(self, texts: List[str], add_instruction: bool = True) -> List[List[float]]:
        """Embed texts one at a time with attention-masked mean pooling.

        Args:
            texts: document chunks to embed.
            add_instruction: prepend DOCUMENT_INSTRUCTION to each text.
                embed_query disables this because it adds its own prefix.

        Returns:
            One embedding (list of floats) per input text, in order.

        Raises:
            RuntimeError: fail-fast if any single text cannot be embedded.
        """
        embeddings = []

        with torch.no_grad():
            for text in texts:
                try:
                    # Normalize whitespace / page markers before tokenizing.
                    clean_text = self.normalize_text(text)

                    # Instruction prefix steers the embedding space.
                    if add_instruction:
                        prefixed_text = DOCUMENT_INSTRUCTION + clean_text
                    else:
                        prefixed_text = clean_text

                    inputs = self.tokenizer(
                        prefixed_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        # model_max_length can be a huge sentinel value, so cap
                        # at 2048; fall back to 512 when unset/falsy.
                        max_length=min(self.tokenizer.model_max_length or 512, 2048)
                    ).to(self.device)

                    outputs = self.model(**inputs, output_hidden_states=True)

                    if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                        # Mean-pool the last hidden layer, weighting by the
                        # attention mask so padding tokens contribute nothing.
                        hidden_states = outputs.hidden_states[-1]
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()

                        # Apply attention mask as weights
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        # clamp avoids divide-by-zero on an all-padding row.
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)

                        # Weighted mean
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        # Fallback: same masked mean, over logits, when hidden
                        # states are not returned by this model.
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_logits = outputs.logits * attention_mask
                        sum_embeddings = weighted_logits.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()

                    # Move to CPU and convert to a plain list for storage.
                    embeddings.append(embedding.cpu().numpy().tolist())

                except Exception as e:
                    print(f" ❌ CRITICAL: Failed to embed text: {e}")
                    print(f" Text preview: {text[:100]}")
                    raise RuntimeError(f"Embedding failed for text: {text[:50]}...") from e

        return embeddings
|
| 168 |
+
|
| 169 |
+
def embed_query(self, text: str) -> List[float]:
|
| 170 |
+
"""Embed query with query-specific instruction."""
|
| 171 |
+
# β
DIFFERENT INSTRUCTION FOR QUERIES
|
| 172 |
+
clean_text = self.normalize_text(text)
|
| 173 |
+
prefixed_text = QUERY_INSTRUCTION + clean_text
|
| 174 |
+
|
| 175 |
+
# Don't add document instruction again
|
| 176 |
+
return self.embed_documents([prefixed_text], add_instruction=False)[0]
|
| 177 |
+
|
| 178 |
+
def embed_image(self, image_path: str, prompt: str = "Technical diagram") -> Optional[List[float]]:
    """Embed image with Qwen2-VL proper format.

    Opens the image, wraps it in a Qwen2-VL chat-style message, runs the
    model with hidden states enabled, and mean-pools the last hidden
    layer (attention-weighted when a mask is available) into one vector.

    Args:
        image_path: Path to the image file on disk.
        prompt: Text paired with the image in the chat message.

    Returns:
        The embedding as a list of floats, or None when the image cannot
        be processed or the model exposes neither hidden states nor a
        pooler output.
    """
    try:
        # Inference only — no gradients needed.
        with torch.no_grad():
            image = Image.open(image_path).convert('RGB')

            # Qwen2-VL expects a chat message with interleaved
            # image and text content parts.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]

            # Apply chat template (produces the model's prompt string).
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Process with both text and images
            inputs = self.processor(
                text=[text],
                images=[image],
                return_tensors="pt",
                padding=True
            ).to(self.device)

            outputs = self.model(**inputs, output_hidden_states=True)

            if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                # Last transformer layer's token representations.
                hidden_states = outputs.hidden_states[-1]

                # Use weighted mean pooling: mask out padding tokens so
                # they do not dilute the average.
                if 'attention_mask' in inputs:
                    attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                    weighted_hidden_states = hidden_states * attention_mask
                    sum_embeddings = weighted_hidden_states.sum(dim=1)
                    # clamp avoids division by zero on an all-zero mask
                    sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                    embedding = (sum_embeddings / sum_mask).squeeze()
                else:
                    # No mask available: plain mean over the sequence.
                    embedding = hidden_states.mean(dim=1).squeeze()
            else:
                # Fallback to pooler output if available
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    embedding = outputs.pooler_output.squeeze()
                else:
                    return None

            return embedding.cpu().numpy().tolist()

    except Exception as e:
        # Image embedding is best-effort; callers treat None as "skip".
        print(f" β οΈ Failed to embed image {Path(image_path).name}: {str(e)[:100]}")
        return None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ========================================
|
| 239 |
+
# π QUERY PREPROCESSING
|
| 240 |
+
# ========================================
|
| 241 |
+
def preprocess_query(query: str) -> str:
    """Normalize a user query for retrieval.

    Lowercases, expands common gas-safety abbreviations to full phrases,
    strips trailing ?/! runs, and collapses internal whitespace.
    """
    expansions = {
        r'\bh2s\b': 'hydrogen sulfide',
        r'\bppm\b': 'parts per million',
        r'\bppe\b': 'personal protective equipment',
        r'\bscba\b': 'self contained breathing apparatus',
        r'\blel\b': 'lower explosive limit',
        r'\bhel\b': 'higher explosive limit',
        r'\buel\b': 'upper explosive limit'
    }

    normalized = query.lower()
    for pattern, replacement in expansions.items():
        normalized = re.sub(pattern, replacement, normalized)

    # Drop trailing question/exclamation marks.
    normalized = re.sub(r'[?!]+$', '', normalized)

    # Collapse runs of whitespace into single spaces.
    return re.sub(r'\s+', ' ', normalized).strip()
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ========================================
|
| 268 |
+
# π TABLE EXTRACTION
|
| 269 |
+
# ========================================
|
| 270 |
+
def is_table_of_contents_header(df, page_num):
    """Detect a table-of-contents table by scanning its header row.

    Only early pages (<= 15) are considered; at least two TOC keywords
    must appear in the first row for a positive match.
    """
    if len(df) == 0 or page_num > 15:
        return False

    # Join the header row into one lowercase string for keyword search.
    header_text = ' '.join(df.iloc[0].astype(str)).lower()

    toc_keywords = ('section', 'subsection', 'description', 'page no', 'page number', 'contents')

    # Two or more keyword hits means this is a TOC header.
    hits = sum(keyword in header_text for keyword in toc_keywords)
    return hits >= 2
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def looks_like_toc_data(df):
    """Heuristic: does this table body look like TOC rows?

    TOC continuation pages have section numbers ("10.1") in the first
    column and page numbers (50..300) in the last column.
    """
    if len(df) < 2 or len(df.columns) < 2:
        return False

    # Last column (minus header row) should be dominated by values in a
    # plausible page-number range.
    page_col = df.iloc[1:, -1].astype(str)
    page_like = 0
    for raw in page_col:
        cleaned = raw.strip()
        if cleaned.isdigit() and 50 < int(cleaned) < 300:
            page_like += 1

    if len(page_col) > 0 and page_like / len(page_col) > 0.7:
        # First column should mostly hold section numbers like "10.1".
        section_col = df.iloc[1:, 0].astype(str)
        section_like = sum(
            1 for raw in section_col if re.match(r'^\d+\.?\d*$', raw.strip())
        )
        if section_like / len(section_col) > 0.5:
            return True

    return False
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def extract_tables_from_pdf(pdf_path: str) -> List[Document]:
    """Extract bordered tables with smart TOC detection.

    Uses camelot's lattice flavor (bordered tables only), skips the
    first 5 title pages, deduplicates repeated tables by (page, header)
    signature, drops table-of-contents pages and their continuations,
    and converts each surviving table to a natural-language Document.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        List of Document chunks, one per accepted table (empty on failure).
    """
    chunks = []
    # Pre-initialize so the cleanup in `finally` never hits a NameError
    # when camelot.read_pdf raises before assignment (the original used
    # a bare `except:` to paper over exactly that).
    lattice_tables = None
    all_tables = []

    try:
        lattice_tables = camelot.read_pdf(
            pdf_path,
            pages='all',
            flavor='lattice',  # Only bordered tables
            suppress_stdout=True
        )

        all_tables = list(lattice_tables)
        seen_tables = set()

        # Track whether we are currently inside a TOC run of pages.
        in_toc_section = False

        print(f" π Found {len(all_tables)} bordered tables")

        for table in all_tables:
            df = table.df
            current_page = table.page

            # Deduplicate by (page, header-row) signature.
            table_id = (current_page, tuple(df.iloc[0].tolist()) if len(df) > 0 else ())
            if table_id in seen_tables:
                continue
            seen_tables.add(table_id)

            # Skip first 5 pages (title pages)
            if current_page <= 5:
                continue

            # Basic validation: tiny or low-confidence detections are noise.
            if len(df.columns) < 2 or len(df) < 3 or table.accuracy < 80:
                continue

            # Detect TOC start (a page whose header row names TOC columns).
            if not in_toc_section and is_table_of_contents_header(df, current_page):
                in_toc_section = True
                print(f" π TOC detected at page {current_page}")
                continue

            # While inside the TOC, skip continuation pages until the data
            # stops looking like section/page-number pairs.
            if in_toc_section:
                if looks_like_toc_data(df):
                    print(f" βοΈ Skipping TOC continuation on page {current_page}")
                    continue
                else:
                    print(f" β TOC ended, found real table on page {current_page}")
                    in_toc_section = False

            # Extract valid table as natural-language sentences.
            table_text = table_to_natural_language_enhanced(table)

            if table_text.strip():
                chunks.append(Document(
                    page_content=table_text,
                    metadata={
                        "source": os.path.basename(pdf_path),
                        "page": current_page,
                        "heading": "Table Data",
                        "type": "table",
                        "table_accuracy": table.accuracy
                    }
                ))

        print(f" β Extracted {len(chunks)} valid tables (after TOC filtering)")

    except Exception as e:
        print(f"β οΈ Table extraction failed: {e}")

    finally:
        # Release camelot handles promptly; catch Exception (not bare
        # except) so KeyboardInterrupt etc. still propagate.
        try:
            del lattice_tables
            del all_tables
            gc.collect()
            time.sleep(0.1)
        except Exception:
            pass

    return chunks
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def table_to_natural_language_enhanced(table) -> str:
    """Convert a camelot table into natural-language sentences.

    Row 0 is treated as headers; each subsequent non-empty row whose
    first cell is non-blank becomes "<first cell> has <header>: <value>,
    ...", one sentence per line.
    """
    df = table.df

    if len(df) < 2:
        return ""

    blank_values = ('', 'nan', 'none')

    raw_headers = [str(h).strip() for h in df.iloc[0].astype(str).tolist()]
    headers = [
        h if h and h.lower() not in blank_values else f"Column_{i}"
        for i, h in enumerate(raw_headers)
    ]

    sentences = []

    for row_idx in range(1, len(df)):
        cells = [str(c).strip() for c in df.iloc[row_idx].astype(str).tolist()]

        # Skip rows that are entirely blank/NaN.
        if not any(c and c.lower() not in blank_values for c in cells):
            continue

        # The first cell acts as the sentence subject; rows without one
        # are dropped.
        if len(cells) > 0 and cells[0] and cells[0].lower() not in blank_values:
            attributes = [
                f"{headers[i]}: {cells[i]}"
                for i in range(1, min(len(cells), len(headers)))
                if cells[i] and cells[i].lower() not in blank_values
            ]

            if attributes:
                sentences.append(f"{cells[0]} has {', '.join(attributes)}.")
            else:
                sentences.append(f"{cells[0]}.")

    return "\n".join(sentences)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def extract_tables_with_ocr(pdf_path: str, page_num: int) -> List[Dict]:
    """OCR fallback for image-based PDFs.

    Renders one page to an image, runs Tesseract, and keeps lines that
    look tabular (tab characters or runs of 2+ spaces). Returns a single
    pseudo-table dict when more than two such lines are found.
    """
    try:
        rendered = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)

        if not rendered:
            return []

        raw_text = pytesseract.image_to_string(rendered[0])

        # Lines with tabs or wide gaps are likely table rows.
        tabular_lines = [
            line for line in raw_text.split('\n')
            if '\t' in line or re.search(r'\s{2,}', line)
        ]

        if len(tabular_lines) > 2:
            return [{
                "text": "\n".join(tabular_lines),
                "page": page_num,
                "method": "ocr"
            }]

        return []

    except Exception:
        # OCR is strictly best-effort: any failure yields no tables.
        return []
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def get_table_regions(pdf_path: str) -> Dict[int, List[tuple]]:
    """Collect table bounding boxes per page using BOTH camelot flavors.

    Lattice (bordered) and stream (whitespace) detections are merged so
    text extraction can later skip anything inside a table. TOC tables
    are excluded — their text should remain available as plain text.
    """
    regions = {}

    try:
        lattice = camelot.read_pdf(pdf_path, pages='all', flavor='lattice', suppress_stdout=True)
        stream = camelot.read_pdf(pdf_path, pages='all', flavor='stream', suppress_stdout=True)

        for table in list(lattice) + list(stream):
            page = table.page

            # Don't mask TOC text out of the text-extraction pass.
            if is_table_of_contents_header(table.df, page):
                continue

            # NOTE(review): relies on camelot's private _bbox attribute.
            bbox = table._bbox

            page_boxes = regions.setdefault(page, [])
            if bbox not in page_boxes:
                page_boxes.append(bbox)

    except Exception:
        # Best-effort: with no regions, text extraction simply keeps all text.
        pass

    return regions
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# ========================================
|
| 494 |
+
# πΌοΈ IMAGE EXTRACTION
|
| 495 |
+
# ========================================
|
| 496 |
+
def extract_images_from_pdf(pdf_path: str, output_dir: str) -> List[Dict]:
    """Extract embedded images from a PDF into *output_dir*.

    Images smaller than 10 KB (icons, bullets, decorations) are skipped.
    Each saved image is recorded with its path, 1-based page number,
    source PDF basename, and a "type": "image" tag.

    Args:
        pdf_path: Path to the PDF file.
        output_dir: Directory where extracted PNGs are written.

    Returns:
        List of metadata dicts for the images actually saved.
    """
    doc = fitz.open(pdf_path)
    image_data = []

    try:
        for page_num in range(len(doc)):
            page = doc[page_num]
            images = page.get_images()

            for img_index, img in enumerate(images):
                try:
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]

                    # Ignore tiny images (logos, bullets, decorations).
                    if len(image_bytes) < 10000:
                        continue

                    image_filename = f"{Path(pdf_path).stem}_p{page_num+1}_img{img_index+1}.png"
                    image_path = os.path.join(output_dir, image_filename)

                    with open(image_path, "wb") as img_file:
                        img_file.write(image_bytes)

                    image_data.append({
                        "path": image_path,
                        "page": page_num + 1,
                        "source": os.path.basename(pdf_path),
                        "type": "image"
                    })

                except Exception:
                    # One bad image must not abort the whole PDF.
                    continue
    finally:
        # Fix: always release the PyMuPDF handle, even if an exception
        # escapes the page loop (previously the handle leaked).
        doc.close()

    return image_data
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# ========================================
|
| 535 |
+
# π TEXT EXTRACTION WITH OVERLAPPING CHUNKS
|
| 536 |
+
# ========================================
|
| 537 |
+
def is_bold_text(span):
    """True when a PyMuPDF span is bold — by font name or flag bit 4."""
    named_bold = "bold" in span['font'].lower()
    flagged_bold = span['flags'] & 2**4
    return named_bold or flagged_bold
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def is_likely_heading(text, font_size, is_bold, avg_font_size):
    """Heuristic heading detector.

    A heading must be bold and 3-100 characters; it qualifies when its
    font is >10% larger than average, it is ALL CAPS, or it is numbered
    like "3.1 Title".
    """
    if not is_bold:
        return False

    candidate = text.strip()

    # Too long to be a heading, or too short to be meaningful.
    if not (3 <= len(candidate) <= 100):
        return False

    if font_size > avg_font_size * 1.1:
        return True

    return bool(candidate.isupper() or re.match(r'^\d+\.?\d*\s+[A-Z]', candidate))
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def is_inside_table(block_bbox, table_bboxes):
    """Return True when a text block's bbox overlaps any table bbox."""
    bx1, by1, bx2, by2 = block_bbox

    for tx1, ty1, tx2, ty2 in table_bboxes:
        # Two rectangles overlap unless one lies entirely to one side
        # of the other (standard separating-axis check).
        separated = bx2 < tx1 or bx1 > tx2 or by2 < ty1 or by1 > ty2
        if not separated:
            return True

    return False
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def split_text_with_overlap(text: str, heading: str, source: str, page: int,
                            chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[Document]:
    """Split *text* into overlapping word-window chunks.

    Every chunk is prefixed with "Section: <heading>" so the retriever
    sees its context, and carries the full section text in metadata as
    parent_text for later expansion.
    """

    words = text.split()

    # Short sections fit in one chunk — no windowing needed.
    if len(words) <= chunk_size:
        return [Document(
            page_content=f"Section: {heading}\n\n{text}",
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": 0,
                "total_chunks": 1
            }
        )]

    documents = []
    step = chunk_size - overlap

    for chunk_index, start in enumerate(range(0, len(words), step)):
        window = words[start:start + chunk_size]

        # Drop a tiny trailing remainder once earlier chunks exist.
        if len(window) < MIN_CHUNK_SIZE and documents:
            break

        body = " ".join(window)

        documents.append(Document(
            page_content=f"Section: {heading}\n\n{body}",
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": chunk_index,
                "start_word": start,
                "end_word": start + len(window)
            }
        ))

    # Backfill the final chunk count on every chunk.
    for document in documents:
        document.metadata["total_chunks"] = len(documents)

    return documents
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def extract_text_chunks_with_overlap(pdf_path: str, table_regions: Dict[int, List[tuple]]) -> List[Document]:
    """Extract text with overlapping chunks.

    Two passes over the PDF with PyMuPDF: first, gather every span's font
    size to estimate the average body-text size (heading detection is
    relative to it); second, walk the lines in order, skip anything
    inside a known table region, and accumulate lines into sections that
    start at each detected heading. Each section is then chunked with
    overlap via split_text_with_overlap.

    Args:
        pdf_path: Path to the PDF file.
        table_regions: Mapping of 1-based page number -> list of table
            bounding boxes (from get_table_regions) whose text to skip.

    Returns:
        List of overlapping text-chunk Documents with heading metadata.
    """
    doc = fitz.open(pdf_path)

    # Pass 1: collect all font sizes to estimate the body-text baseline.
    all_font_sizes = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        for block in blocks:
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        all_font_sizes.append(span["size"])

    # 12 pt is a reasonable default when the PDF exposes no spans.
    avg_font_size = sum(all_font_sizes) / len(all_font_sizes) if all_font_sizes else 12

    # Pass 2: accumulate lines into heading-delimited sections.
    sections = []
    current_section = ""
    current_heading = "Introduction"  # label for text before any heading
    current_page = 1

    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]

        # table_regions keys are 1-based page numbers.
        page_tables = table_regions.get(page_num + 1, [])

        for block in blocks:
            # Image-only blocks carry no "lines" key.
            if "lines" not in block:
                continue

            # Skip text that lies inside a table (handled separately).
            block_bbox = block.get("bbox", (0, 0, 0, 0))
            if is_inside_table(block_bbox, page_tables):
                continue

            for line in block["lines"]:
                line_text = ""
                line_is_bold = False
                line_font_size = 0

                # Merge the line's spans; a single bold span marks the
                # whole line bold, and the largest span size wins.
                for span in line["spans"]:
                    line_text += span["text"]
                    if is_bold_text(span):
                        line_is_bold = True
                    line_font_size = max(line_font_size, span["size"])

                line_text = line_text.strip()
                if not line_text:
                    continue

                if is_likely_heading(line_text, line_font_size, line_is_bold, avg_font_size):
                    # Close out the running section before starting a new one.
                    if current_section.strip():
                        sections.append({
                            "text": current_section.strip(),
                            "heading": current_heading,
                            "page": current_page,
                            "source": os.path.basename(pdf_path)
                        })

                    current_heading = line_text
                    current_section = ""
                    current_page = page_num + 1
                else:
                    current_section += line_text + " "

    # Flush the final section.
    if current_section.strip():
        sections.append({
            "text": current_section.strip(),
            "heading": current_heading,
            "page": current_page,
            "source": os.path.basename(pdf_path)
        })

    doc.close()

    # Chunk every section with overlap, carrying heading context.
    all_chunks = []

    for section in sections:
        chunks = split_text_with_overlap(
            text=section['text'],
            heading=section['heading'],
            source=section['source'],
            page=section['page'],
            chunk_size=CHUNK_SIZE,
            overlap=OVERLAP
        )
        all_chunks.extend(chunks)

    return all_chunks
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
# ========================================
|
| 718 |
+
# π COMBINED EXTRACTION
|
| 719 |
+
# ========================================
|
| 720 |
+
def extract_all_content_from_pdf(pdf_path: str) -> Tuple[List[Document], List[Dict]]:
    """Run table, text, and image extraction for one PDF.

    Returns the text and table chunks merged into one list, plus the
    image metadata list. Table regions are computed first so text
    extraction can exclude them.
    """

    print(f" π Extracting tables...")
    table_regions = get_table_regions(pdf_path)
    table_chunks = extract_tables_from_pdf(pdf_path)
    print(f" β {len(table_chunks)} table chunks")

    print(f" π Extracting text...")
    text_chunks = extract_text_chunks_with_overlap(pdf_path, table_regions)
    print(f" β {len(text_chunks)} text chunks")

    print(f" πΌοΈ Extracting images...")
    images = extract_images_from_pdf(pdf_path, IMAGE_OUTPUT_DIR)
    print(f" β {len(images)} images")

    combined_chunks = text_chunks + table_chunks
    return combined_chunks, images
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
# ========================================
|
| 742 |
+
# ποΈ BUILD FAISS INDEX WITH STREAMING
|
| 743 |
+
# ========================================
|
| 744 |
+
# Replace the HybridRetriever class and related functions with this optimized version:
|
| 745 |
+
|
| 746 |
+
# ========================================
|
| 747 |
+
# ποΈ BUILD FAISS INDEX WITH BM25
|
| 748 |
+
# ========================================
|
| 749 |
+
def build_multimodal_faiss_streaming(pdf_files: List[str], embedding_model: VLM2VecEmbeddings):
    """Build FAISS index with streaming and BM25.

    Pipeline: extract text/table/image content from each PDF; embed text
    in small batches into a FAISS inner-product index (vectors are
    L2-normalized so inner product equals cosine similarity); build and
    persist a BM25 keyword index; embed images into a second FAISS
    index; finally write an MD5 hash of the input file list so future
    runs can detect an up-to-date index.

    Args:
        pdf_files: Paths of the PDFs to index.
        embedding_model: VLM2Vec wrapper providing embed_documents and
            embed_image.

    Returns:
        (text_index, all_texts) on success, or (None, []) when nothing
        was extracted or the user declined a rebuild.

    NOTE(review): uses input() to confirm a rebuild — this blocks in
    headless deployments (e.g. HF Spaces); confirm it is intended.
    """

    # Hash of the sorted file list identifies this exact corpus.
    index_hash_file = f"{FAISS_INDEX_PATH}/index_hash.txt"
    current_hash = hashlib.md5("".join(sorted(pdf_files)).encode()).hexdigest()

    if os.path.exists(index_hash_file):
        with open(index_hash_file, 'r') as f:
            existing_hash = f.read().strip()

        if existing_hash == current_hash:
            print("β οΈ Index already exists for these PDFs!")
            # Interactive confirmation before discarding the old index.
            response = input(" Rebuild anyway? (yes/no): ").strip().lower()
            if response != 'yes':
                return None, []

    all_texts = []
    all_image_paths = []

    print("\nπ Processing PDFs...\n")

    for pdf_file in pdf_files:
        print(f"π Processing: {Path(pdf_file).name}")

        try:
            text_chunks, images = extract_all_content_from_pdf(pdf_file)

            all_texts.extend(text_chunks)
            all_image_paths.extend(images)

        except Exception as e:
            # One bad PDF must not abort the whole build.
            print(f" β Error: {e}")
            continue

        print()

    print(f"β Total chunks: {len(all_texts)}")
    print(f"β Total images: {len(all_image_paths)}\n")

    if len(all_texts) == 0:
        print("β No content extracted!")
        return None, []

    # Build text index
    print("π Generating text embeddings...\n")

    text_index = None
    batch_size = 10  # small batches keep peak memory low

    for i in range(0, len(all_texts), batch_size):
        batch = all_texts[i:i+batch_size]
        batch_contents = [doc.page_content for doc in batch]

        try:
            batch_embeddings = embedding_model.embed_documents(batch_contents, add_instruction=True)
            batch_embeddings_np = np.array(batch_embeddings).astype('float32')

            if text_index is None:
                # Create the index lazily, once the embedding dim is known.
                dimension = batch_embeddings_np.shape[1]
                text_index = faiss.IndexFlatIP(dimension)
                print(f" Text embedding dimension: {dimension}")

            # Normalize so inner-product search == cosine similarity.
            faiss.normalize_L2(batch_embeddings_np)
            text_index.add(batch_embeddings_np)

            if (i // batch_size + 1) % 5 == 0:
                print(f" Progress: {i + len(batch)}/{len(all_texts)}")

        except Exception as e:
            # Text embeddings are mandatory — abort the build on failure.
            print(f" β Error: {e}")
            raise

    print(f" β Complete")

    # Save FAISS index
    faiss.write_index(text_index, f"{FAISS_INDEX_PATH}/text_index.faiss")

    # Save documents
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "wb") as f:
        pickle.dump(all_texts, f)

    # Build and persist a BM25 keyword index alongside FAISS.
    print("\nπ Building BM25 index for keyword search...")
    tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
    bm25_index = BM25Okapi(tokenized_docs,k1=1.3, b=0.65)

    with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
        pickle.dump(bm25_index, f)

    print(" β BM25 index saved")

    # Build image index (optional — only when images were extracted).
    if len(all_image_paths) > 0:
        print(f"\nπΌοΈ Embedding images...")

        image_index = None
        successful_images = []  # only images that actually embedded

        for idx, img_data in enumerate(all_image_paths):
            img_embedding = embedding_model.embed_image(img_data["path"])

            # embed_image returns None on failure; skip those images.
            if img_embedding is None:
                continue

            img_embedding_np = np.array([img_embedding]).astype('float32')

            if image_index is None:
                dimension = img_embedding_np.shape[1]
                image_index = faiss.IndexFlatIP(dimension)
                print(f" Image dimension: {dimension}")

            faiss.normalize_L2(img_embedding_np)
            image_index.add(img_embedding_np)
            successful_images.append(img_data)

            if (len(successful_images)) % 10 == 0:
                print(f" Progress: {len(successful_images)}/{len(all_image_paths)}")

        print(f" β {len(successful_images)} images embedded")

        if image_index is not None and len(successful_images) > 0:
            faiss.write_index(image_index, f"{FAISS_INDEX_PATH}/image_index.faiss")

            with open(f"{FAISS_INDEX_PATH}/image_documents.pkl", "wb") as f:
                pickle.dump(successful_images, f)

    # Save hash so future runs can detect an up-to-date index.
    with open(index_hash_file, 'w') as f:
        f.write(current_hash)

    print(f"\nβ Index saved: {FAISS_INDEX_PATH}\n")

    return text_index, all_texts
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
# ========================================
|
| 885 |
+
# π OPTIMIZED HYBRID SEARCH
|
| 886 |
+
# ========================================
|
| 887 |
+
# ========================================
|
| 888 |
+
# π QUERY WITH BM25 ONLY
|
| 889 |
+
# ========================================
|
| 890 |
+
def query_with_bm25(query: str, k_text: int = 5, k_images: int = 3):
    """Query using BM25 keyword search only.

    Loads the persisted document list and BM25 index (rebuilding the
    BM25 index on the fly if missing), ranks all documents by BM25
    score, and attaches images that appear on the same pages as the top
    text hits (no semantic image search here).

    Args:
        query: Raw user query.
        k_text: Number of text results to return.
        k_images: Maximum number of page-matched images to return.

    Returns:
        Dict with keys: text_results, images, query, processed_query.
    """

    # Expand abbreviations / normalize before tokenizing.
    processed_query = preprocess_query(query)
    print(f" π Processed: {processed_query}")

    # Load documents
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)

    # Load the persisted BM25 index; rebuild on the fly if missing.
    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        print(" β οΈ BM25 index not found, building on-the-fly...")
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        # Fix: use the same k1/b as the index built at indexing time
        # (BM25Okapi(..., k1=1.3, b=0.65)), so fallback scores are
        # consistent with the persisted index.
        bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)

    # BM25 keyword scoring over the whole corpus.
    tokenized_query = processed_query.lower().split()
    bm25_scores = bm25_index.get_scores(tokenized_query)

    # Top-k documents by descending score.
    top_indices = np.argsort(bm25_scores)[::-1][:k_text]

    text_results = []
    relevant_pages = set()

    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        score = float(bm25_scores[idx])

        text_results.append({
            "document": doc,
            "score": score,
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))

    # Attach images by page co-occurrence with the top text hits.
    relevant_images = []

    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"

        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)

            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # page-matched, not similarity-scored
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })

    except Exception:
        # Best-effort: missing/corrupt image metadata never fails the query.
        pass

    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query
    }
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
# ========================================
|
| 968 |
+
# π DISPLAY RESULTS (BM25 ONLY)
|
| 969 |
+
# ========================================
|
| 970 |
+
def display_results_bm25(results: Dict):
    """Pretty-print BM25 search results: text hits first, then images."""

    print("\nπ TOP RESULTS (BM25 Keyword Search):\n")

    for entry in results['text_results']:
        doc = entry["document"]
        meta = doc.metadata

        print(f"[{entry['rank']}] BM25 Score: {entry['score']:.4f} | {meta.get('type', 'N/A')}")
        print(f" π {meta.get('source')} - Page {meta.get('page')}")
        print(f" π {meta.get('heading', 'N/A')[:60]}")

        # Only show chunk position for sections split into multiple chunks.
        if 'total_chunks' in meta and meta.get('total_chunks', 1) > 1:
            print(f" π Chunk {meta.get('chunk_index', 0)+1}/{meta.get('total_chunks')}")

        print(f" π {doc.page_content[:200]}...")
        print()

    print("\nπΌοΈ IMAGES:\n")

    if not results['images']:
        print(" No images found\n")
    else:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
# ========================================
|
| 997 |
+
# π HYBRID SEARCH IMPLEMENTATION
|
| 998 |
+
# ========================================
|
| 999 |
+
|
| 1000 |
+
def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Min-max normalize *scores* into [0, 1].

    Empty input is returned unchanged; a constant array maps to all
    ones (avoids division by zero).
    """
    if len(scores) == 0:
        return scores

    lo = np.min(scores)
    hi = np.max(scores)

    if hi == lo:
        return np.ones_like(scores)

    return (scores - lo) / (hi - lo)
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
def query_with_hybrid(query: str, embedding_model: VLM2VecEmbeddings,
                      k_text: int = 5, k_images: int = 3,
                      dense_weight: float = DENSE_WEIGHT,
                      sparse_weight: float = SPARSE_WEIGHT):
    """Hybrid search combining semantic (FAISS) and keyword (BM25) retrieval.

    Runs both retrievers over the pickled text corpus, min-max normalizes
    each score set, fuses them as a weighted sum, and attaches images that
    live on the same pages as the top text hits.

    Args:
        query: Raw user query.
        embedding_model: Embedding model exposing embed_query().
        k_text: Number of fused text results to return.
        k_images: Max number of page-associated images to return.
        dense_weight: Weight of the normalized semantic score.
        sparse_weight: Weight of the normalized BM25 score.

    Returns:
        Dict with keys 'text_results', 'images', 'query',
        'processed_query', and 'method' == 'hybrid'.
    """
    processed_query = preprocess_query(query)
    print(f" π Processed: {processed_query}")

    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)

    # ---- SEMANTIC SEARCH (dense, FAISS) ----
    print(f" π§ Running semantic search...")

    try:
        text_index = faiss.read_index(f"{FAISS_INDEX_PATH}/text_index.faiss")

        query_embedding = embedding_model.embed_query(processed_query)
        query_np = np.array([query_embedding]).astype('float32')
        faiss.normalize_L2(query_np)

        # Over-retrieve (3x) so BM25 fusion has candidates to promote.
        k_retrieve = min(k_text * 3, len(text_docs))
        distances, indices = text_index.search(query_np, k_retrieve)

        semantic_scores = distances[0]
        semantic_indices = indices[0]

        print(f" β Retrieved {len(semantic_indices)} semantic results")

    except Exception as e:
        # Degrade gracefully to BM25-only retrieval if the dense path fails.
        print(f" β οΈ Semantic search failed: {e}")
        semantic_scores = np.array([])
        semantic_indices = np.array([])

    # ---- BM25 SEARCH (sparse, keyword) ----
    print(f" π€ Running BM25 keyword search...")

    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        # No persisted index: build one on the fly from the loaded corpus.
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)

    tokenized_query = processed_query.lower().split()
    bm25_scores_all = bm25_index.get_scores(tokenized_query)

    print(f" β Scored {len(bm25_scores_all)} documents")

    # ---- SCORE FUSION (weighted sum of normalized scores) ----
    print(f" βοΈ Fusing scores (semantic: {dense_weight}, BM25: {sparse_weight})...")

    combined_scores = {}

    if len(semantic_scores) > 0:
        semantic_scores_norm = normalize_scores(semantic_scores)

        for idx, score in zip(semantic_indices, semantic_scores_norm):
            # FIX: FAISS pads missing hits with index -1 when fewer than
            # k_retrieve matches exist; the old `idx < len(text_docs)` check
            # let -1 through and silently scored text_docs[-1].
            if 0 <= idx < len(text_docs):
                combined_scores[idx] = dense_weight * score

    bm25_scores_norm = normalize_scores(bm25_scores_all)

    for idx, score in enumerate(bm25_scores_norm):
        if idx in combined_scores:
            combined_scores[idx] += sparse_weight * score
        else:
            combined_scores[idx] = sparse_weight * score

    sorted_indices = sorted(combined_scores.keys(),
                            key=lambda x: combined_scores[x],
                            reverse=True)

    top_indices = sorted_indices[:k_text]

    print(f" β Top {len(top_indices)} results selected")

    # ---- PREPARE RESULTS ----
    text_results = []
    relevant_pages = set()

    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]

        # semantic_scores_norm only exists when the dense path succeeded;
        # the `idx in semantic_indices` guard short-circuits otherwise.
        semantic_score = semantic_scores_norm[np.where(semantic_indices == idx)[0][0]] if idx in semantic_indices else 0.0
        bm25_score = bm25_scores_norm[idx]
        combined_score = combined_scores[idx]

        text_results.append({
            "document": doc,
            "score": combined_score,
            "semantic_score": float(semantic_score),
            "bm25_score": float(bm25_score),
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))

    # ---- GET IMAGES from the same pages as the top text hits ----
    relevant_images = []

    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"

        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)

            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception:
        # Images are deliberately best-effort: a missing/corrupt image pickle
        # must not fail the text query.
        pass

    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query,
        "method": "hybrid"
    }
def display_results_hybrid(results: Dict) -> None:
    """Pretty-print hybrid (semantic + BM25) search results to stdout.

    Args:
        results: Dict from query_with_hybrid(). 'text_results' entries carry
            'rank', 'score', 'semantic_score', 'bm25_score' and 'document'
            (with .metadata / .page_content); 'images' entries carry
            'rank', 'source', 'page', 'path'.
    """
    print("\nπ TOP RESULTS (Hybrid Search: Semantic + BM25):\n")

    for result in results['text_results']:
        doc = result["document"]
        print(f"[{result['rank']}] Combined: {result['score']:.4f} "
              f"(Semantic: {result['semantic_score']:.4f}, BM25: {result['bm25_score']:.4f}) "
              f"| {doc.metadata.get('type', 'N/A')}")
        print(f" π {doc.metadata.get('source')} - Page {doc.metadata.get('page')}")
        # FIX: metadata may store heading=None explicitly; .get() then returns
        # None and slicing it raised TypeError. `or 'N/A'` covers that case.
        print(f" π {(doc.metadata.get('heading') or 'N/A')[:60]}")

        # Only show chunk position for documents that were actually split.
        if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1:
            print(f" π Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}")

        print(f" π {doc.page_content[:200]}...")
        print()

    print("\nπΌοΈ IMAGES:\n")
    if results['images']:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
    else:
        print(" No images found\n")
# ========================================
# π GET CONTEXT WITH PARENTS
# ========================================
def get_context_with_parents(results: Dict) -> List[Dict]:
    """Extract full parent contexts.

    For each retrieved chunk, prefer the full parent section text stored in
    metadata['parent_text']; a given parent section is emitted at most once.
    Chunks without a parent fall back to their own page_content.
    """
    contexts = []
    emitted_parents = set()

    for hit in results['text_results']:
        doc = hit['document']
        meta = doc.metadata
        parent_text = meta.get('parent_text')

        if parent_text:
            if parent_text in emitted_parents:
                continue  # this parent section was already emitted
            emitted_parents.add(parent_text)
            body, from_parent = parent_text, True
        else:
            body, from_parent = doc.page_content, False

        contexts.append({
            "text": body,
            "source": meta['source'],
            "page": meta['page'],
            "heading": meta['heading'],
            "type": meta.get('type', 'text'),
            "is_parent": from_parent,
        })

    return contexts
# ========================================
# π MAIN EXECUTION (UPDATED FOR HYBRID)
# ========================================
if __name__ == "__main__":
    print("="*70)
    print("π RAG with HYBRID SEARCH (Semantic + BM25)")
    print("="*70 + "\n")

    pdf_files = glob.glob(f"{PDF_DIR}/*.pdf")
    print(f"π Found {len(pdf_files)} PDF files\n")

    if len(pdf_files) == 0:
        print("β No PDFs found!")
        exit(1)

    # Load the embedding model exactly once; it is needed by both the
    # index-build path and the query loop below.
    print("\nπ€ Loading VLM2Vec model...")
    embedding_model = VLM2VecEmbeddings(
        model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B",
        cache_dir=MODEL_CACHE_DIR
    )

    # Load or build index
    if os.path.exists(f"{FAISS_INDEX_PATH}/text_index.faiss"):
        print(f"β Loading existing index\n")

        # Backfill the BM25 index for corpora indexed before hybrid search.
        if not os.path.exists(f"{FAISS_INDEX_PATH}/bm25_index.pkl"):
            print("β οΈ BM25 index missing, building now...")

            with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
                all_texts = pickle.load(f)

            print(" Building BM25 index...")
            tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
            bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)

            with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
                pickle.dump(bm25_index, f)

            print(" β BM25 index saved\n")

    else:
        print("π¨ Building new index...\n")

        # FIX: the model was re-instantiated here although it is already
        # loaded above — a duplicate, slow and memory-heavy model load.
        index, documents = build_multimodal_faiss_streaming(pdf_files, embedding_model)

        if index is None:
            exit(0)

    # Interactive testing
    print("="*70)
    print("π§ͺ TESTING MODE - HYBRID SEARCH")
    print(f" Weights: Semantic {DENSE_WEIGHT} | BM25 {SPARSE_WEIGHT}")
    print("="*70 + "\n")

    test_queries = [
        "What is the higher and lower explosive limit of butane?",
        "What are the precautions taken while handling H2S?",
        "What are the Personal Protection used for Sulfolane?",
        "What is the Composition of Platforming Feed and Product?",
        "Explain Dual function platforming catalyst chemistry.",
        "Steps to be followed in Amine Regeneration Unit for normal shutdown process.",
        "Could you tell me what De-greasing of Amine System in pre startup wash",
    ]

    print("π SUGGESTED QUERIES:")
    for i, q in enumerate(test_queries, 1):
        print(f" {i}. {q}")
    print()
    print("π‘ Type 'mode' to switch between hybrid/bm25/semantic")
    print()

    current_mode = "hybrid"

    while True:
        user_query = input(f"π¬ Query [{current_mode}] (or 1-5, 'mode', or 'exit'): ").strip()

        if user_query.lower() == 'exit':
            print("\nβ Done!")
            break

        if user_query.lower() == 'mode':
            print("\nπ Select mode:")
            print(" 1. Hybrid (Semantic + BM25)")
            print(" 2. BM25 only")
            print(" 3. Semantic only")
            mode_choice = input(" Choice (1-3): ").strip()

            if mode_choice == '1':
                current_mode = "hybrid"
            elif mode_choice == '2':
                current_mode = "bm25"
            elif mode_choice == '3':
                current_mode = "semantic"

            print(f" β Mode set to: {current_mode}\n")
            continue

        # Numeric shortcut: map "1".."7" onto the suggested queries.
        if user_query.isdigit() and 1 <= int(user_query) <= len(test_queries):
            user_query = test_queries[int(user_query) - 1]

        if not user_query:
            continue

        print(f"\n{'='*60}")
        print(f"π Query: {user_query}")
        print(f"π§ Mode: {current_mode.upper()}")
        print(f"{'='*60}\n")

        try:
            if current_mode == "hybrid":
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3)
                display_results_hybrid(results)
            elif current_mode == "bm25":
                results = query_with_bm25(user_query, k_text=5, k_images=3)
                display_results_bm25(results)
            else:  # semantic only: hybrid pipeline with BM25 weight zeroed out
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3,
                                            dense_weight=1.0, sparse_weight=0.0)
                display_results_hybrid(results)

            print("\nπ FULL CONTEXT:\n")
            contexts = get_context_with_parents(results)

            for i, ctx in enumerate(contexts[:3], 1):
                print(f"[{i}] {ctx['heading'][:50]}")
                if ctx['is_parent']:
                    print(f" β Full section")
                print(f" {ctx['text'][:300]}...\n")

            print("="*60 + "\n")

        except Exception as e:
            # Keep the REPL alive on query failure, but show the traceback.
            print(f"\nβ Error: {e}\n")
            import traceback
            traceback.print_exc()