Michtiii committed on
Commit
d4d1a0c
·
verified ·
1 Parent(s): d78645a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +125 -42
  2. requirements.txt +7 -7
app.py CHANGED
@@ -1,65 +1,148 @@
1
  import os
2
- from PyPDF2 import PdfReader
3
- from langchain.text_splitter import RecursiveCharacterTextSplitter
4
- from langchain.vectorstores import FAISS
5
- from langchain.embeddings import SentenceTransformerEmbeddings
6
- from langchain.chains import RetrievalQA
7
- from langchain.chat_models import ChatOpenAI # or HuggingFaceChatModel
8
  import gradio as gr
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # -----------------------------
11
- # 1. Load PDF files
12
  # -----------------------------
13
- docs_path = "Docs"
14
- all_texts = []
 
15
 
16
- for file in os.listdir(docs_path):
17
- if file.endswith(".pdf"):
18
- pdf = PdfReader(os.path.join(docs_path, file))
 
 
 
 
 
 
19
  text = ""
20
- for page in pdf.pages:
21
  text += page.extract_text() or ""
22
- all_texts.append(text)
 
23
 
24
- full_text = "\n".join(all_texts)
 
 
 
 
 
 
 
 
 
 
25
 
26
  # -----------------------------
27
- # 2. Split text into chunks
28
  # -----------------------------
29
- text_splitter = RecursiveCharacterTextSplitter(
30
- chunk_size=1000,
31
- chunk_overlap=200
32
- )
33
- texts = text_splitter.split_text(full_text)
 
34
 
35
  # -----------------------------
36
- # 3. Create embeddings and vector store
37
  # -----------------------------
38
- embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
39
- vectorstore = FAISS.from_texts(texts, embedding_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # -----------------------------
42
- # 4. Create retrieval QA chain
43
  # -----------------------------
44
- llm = ChatOpenAI(temperature=0) # or use HuggingFace model if you prefer
45
- qa = RetrievalQA.from_chain_type(
46
- llm=llm,
47
- retriever=vectorstore.as_retriever(),
48
- chain_type="stuff" # simple summarization chain
49
- )
50
 
51
  # -----------------------------
52
- # 5. Gradio interface
53
  # -----------------------------
54
- def answer_question(query):
55
- return qa.run(query)
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  with gr.Blocks() as demo:
58
- gr.Markdown("# PDF RAG + Summarization Chatbot")
59
- with gr.Row():
60
- query_input = gr.Textbox(label="Ask a question about your PDFs")
61
- output_box = gr.Textbox(label="Answer")
62
- query_input.submit(answer_question, inputs=query_input, outputs=output_box)
63
- gr.Button("Submit").click(answer_question, inputs=query_input, outputs=output_box)
64
-
65
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
1
  import os
2
+ import faiss
3
+ import numpy as np
 
 
 
 
4
  import gradio as gr
5
 
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
+ from PyPDF2 import PdfReader
9
+
10
+ # -----------------------------
11
+ # CONFIG
12
+ # -----------------------------
13
+ DATA_PATH = "Docs"
14
+ TOP_K = 3
15
+
16
+ # -----------------------------
17
+ # EMBEDDING MODEL (LIGHT)
18
+ # -----------------------------
19
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
20
+
21
  # -----------------------------
22
+ # OPEN LLM (NO AUTH REQUIRED)
23
  # -----------------------------
24
+ LLM_MODEL = "google/flan-t5-base"
25
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
26
+ llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL)
27
 
28
# -----------------------------
# FILE LOADER
# -----------------------------
def read_file(path):
    """Return the text content of a .txt/.md/.pdf file.

    Extension matching is case-insensitive (".PDF" or ".TXT" now load
    instead of silently returning ""). Unsupported extensions yield an
    empty string.
    """
    lower = path.lower()
    if lower.endswith((".txt", ".md")):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    elif lower.endswith(".pdf"):
        reader = PdfReader(path)
        text = ""
        for page in reader.pages:
            # extract_text() may return None on image-only pages
            text += page.extract_text() or ""
        return text
    return ""
42
 
43
def load_docs(folder):
    """Load the text of every readable file in *folder*.

    Returns a list of non-empty document strings. A missing folder
    yields [] instead of raising at import time, and individual
    unreadable files are skipped so one bad file cannot abort the load.
    """
    if not os.path.isdir(folder):
        return []
    texts = []
    for file in os.listdir(folder):
        path = os.path.join(folder, file)
        try:
            txt = read_file(path)
        except Exception:
            # best-effort loading: skip unreadable/corrupt files only,
            # never swallow SystemExit/KeyboardInterrupt (bare except did)
            continue
        if txt.strip():
            texts.append(txt)
    return texts
54
 
55
  # -----------------------------
56
+ # CHUNKING
57
  # -----------------------------
58
+ def chunk_text(text, size=300, overlap=50):
59
+ words = text.split()
60
+ chunks = []
61
+ for i in range(0, len(words), size - overlap):
62
+ chunks.append(" ".join(words[i:i + size]))
63
+ return chunks
64
 
65
  # -----------------------------
66
+ # BUILD VECTOR DB
67
  # -----------------------------
68
+ def build_index(docs):
69
+ chunks = []
70
+ for doc in docs:
71
+ chunks.extend(chunk_text(doc))
72
+
73
+ if not chunks:
74
+ return None, []
75
+
76
+ embeddings = embedding_model.encode(chunks)
77
+ dim = embeddings.shape[1]
78
+
79
+ index = faiss.IndexFlatL2(dim)
80
+ index.add(np.array(embeddings))
81
+
82
+ return index, chunks
83
 
84
  # -----------------------------
85
+ # RETRIEVE
86
  # -----------------------------
87
+ def retrieve(query, index, chunks, k=TOP_K):
88
+ q_embed = embedding_model.encode([query])
89
+ D, I = index.search(np.array(q_embed), k)
90
+ return [chunks[i] for i in I[0]]
 
 
91
 
92
# -----------------------------
# GENERATE ANSWER
# -----------------------------
def generate_answer(query, contexts):
    """Generate an answer with the seq2seq LLM, grounded in *contexts*.

    Builds a plain-text prompt that instructs the model to answer only
    from the retrieved context, then decodes a single generation.
    """
    context = "\n\n".join(contexts)

    prompt = f"""
Answer the question based ONLY on the context.
If not found, say: Not in knowledge base.

Context:
{context}

Question:
{query}
"""

    # truncation=True keeps the prompt within the model's max input length;
    # anything past the limit (long contexts) is silently cut off.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = llm_model.generate(**inputs, max_new_tokens=200)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
113
+
114
# -----------------------------
# INIT
# -----------------------------
# Build the vector index once at import time so the Gradio handler can
# query it; index is None when DATA_PATH holds no readable documents.
docs = load_docs(DATA_PATH)
index, chunks = build_index(docs)
119
+
120
# -----------------------------
# RAG PIPELINE
# -----------------------------
def rag(query):
    """Answer *query* from the indexed docs.

    Returns (answer, joined retrieved context); a placeholder pair when
    no documents were indexed.
    """
    if index is None:
        return "No documents found", ""

    hits = retrieve(query, index, chunks)
    return generate_answer(query, hits), "\n\n---\n\n".join(hits)
131
+
132
# -----------------------------
# UI
# -----------------------------
# Minimal Gradio front-end: one question box, answer + retrieved context.
with gr.Blocks() as demo:
    gr.Markdown("## AI/ML Knowledge RAG (Stable Version)")

    q = gr.Textbox(placeholder="Ask about AI tools, companies, ML...")
    ans = gr.Textbox(label="Answer")
    ctx = gr.Textbox(label="Context")

    # rag returns (answer, context), mapped onto the two output boxes
    gr.Button("Ask").click(rag, inputs=q, outputs=[ans, ctx])
143
+
144
# -----------------------------
# RUN
# -----------------------------
# Launch only when executed directly (not when imported, e.g. by tests).
if __name__ == "__main__":
    demo.launch()
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
- langchain==0.1.232
2
- gradio==6.10.0
3
- PyPDF2==3.0.1
4
- faiss-cpu==1.7.4
5
- sentence-transformers==2.2.2
6
- huggingface-hub==0.30.0
7
- transformers==4.33.2
 
1
+ gradio
2
+ faiss-cpu
3
+ sentence-transformers
4
+ transformers
5
+ torch
6
+ PyPDF2
7
+ numpy