Isshi14 committed
Commit 3e4a391 · verified · 1 Parent(s): 7f8878c

Upload 2 files

Files changed (2)
  1. app.py +119 -104
  2. requirements.txt +0 -5
app.py CHANGED
@@ -1,69 +1,93 @@
 import os
 import gradio as gr
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.vectorstores import Chroma
-from langchain_huggingface import HuggingFaceEndpoint
-from langchain.chains import RetrievalQA
-from langchain_core.prompts import PromptTemplate
+from sentence_transformers import SentenceTransformer
+import chromadb
+from huggingface_hub import InferenceClient
 
 # --- Configuration ---
-# You can set your Hugging Face Token here or as an environment variable
-# os.environ["HUGGINGFACEHUB_API_TOKEN"] = "your_token_here"
-
 KNOWLEDGE_BASE_DIR = "knowledge_base"
-PERSIST_DIRECTORY = "chroma_db"
+COLLECTION_NAME = "ai_twin_kb"
 
+# --- Step 1: Load documents from knowledge_base/ ---
 def load_documents():
-    """Loads text documents from the knowledge base directory."""
-    loader = DirectoryLoader(KNOWLEDGE_BASE_DIR, glob="*.txt", loader_cls=TextLoader)
-    documents = loader.load()
-    return documents
-
-def create_vector_store(documents):
-    """Splits documents and creates a Chroma vector store."""
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-    texts = text_splitter.split_documents(documents)
-
-    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    """Loads all .txt files from the knowledge base directory."""
+    documents = []
+    filenames = []
+    for filename in os.listdir(KNOWLEDGE_BASE_DIR):
+        if filename.endswith(".txt"):
+            filepath = os.path.join(KNOWLEDGE_BASE_DIR, filename)
+            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
+                content = f.read().strip()
+            if content:
+                documents.append(content)
+                filenames.append(filename)
+    return documents, filenames
+
+# --- Step 2: Chunk documents ---
+def chunk_text(text, chunk_size=500, overlap=100):
+    """Splits text into overlapping chunks."""
+    chunks = []
+    start = 0
+    while start < len(text):
+        end = start + chunk_size
+        chunks.append(text[start:end])
+        start += chunk_size - overlap
+    return chunks
+
+# --- Step 3: Build vector store ---
+def build_vector_store(documents, filenames):
+    """Creates embeddings and stores them in ChromaDB."""
+    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+    client = chromadb.Client()
+    # Delete existing collection if it exists
+    try:
+        client.delete_collection(COLLECTION_NAME)
+    except:
+        pass
+    collection = client.create_collection(name=COLLECTION_NAME)
 
-    # Check if vector store already exists to avoid re-creating it every time?
-    # For this assignment, re-creating it ensures latest data is used.
-    if os.path.exists(PERSIST_DIRECTORY):
-        try:
-            # simple cleanup for a fresh start (optional for production but good for dev)
-            import shutil
-            shutil.rmtree(PERSIST_DIRECTORY)
-        except:
-            pass
-
-    vector_store = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
-    return vector_store
-
-def setup_rag_chain(vector_store):
-    """Sets up the RAG chain with a retrieval capability."""
-    # Using a free endpoint model.
-    # 'mistralai/Mistral-7B-Instruct-v0.2' is a good choice, but requires a token.
-    # 'google/flan-t5-large' is another option.
-    # We'll use a generic reliable one or let the user input their token/model in a real scenario.
-    # For the assignment, let's try to use a model that might work with the free tier or a locally downloadable one if needed.
-    # However, running a local LLM is heavy.
-    # Let's assume the user has a token or we use a very small model.
-    # If no token is found, this might fail or warn.
+    all_chunks = []
+    all_ids = []
+    all_metadata = []
+    chunk_id = 0
 
-    llm = HuggingFaceEndpoint(
-        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
-        task="text-generation",
-        max_new_tokens=512,
-        do_sample=False,
-        repetition_penalty=1.03,
+    for doc, fname in zip(documents, filenames):
+        chunks = chunk_text(doc)
+        for chunk in chunks:
+            all_chunks.append(chunk)
+            all_ids.append(f"chunk_{chunk_id}")
+            all_metadata.append({"source": fname})
+            chunk_id += 1
+
+    # Generate embeddings
+    embeddings = model.encode(all_chunks).tolist()
+
+    # Add to ChromaDB
+    collection.add(
+        documents=all_chunks,
+        embeddings=embeddings,
+        ids=all_ids,
+        metadatas=all_metadata
     )
 
-    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
+    return collection, model
+
+# --- Step 4: RAG query function ---
+def query_rag(question, collection, embed_model, llm_client):
+    """Retrieves relevant chunks and generates an answer."""
+    # Embed the question
+    q_embedding = embed_model.encode([question]).tolist()
+
+    # Retrieve top 3 relevant chunks
+    results = collection.query(query_embeddings=q_embedding, n_results=3)
+
+    # Build context from retrieved documents
+    context = "\n\n".join(results["documents"][0])
 
-    prompt_template = """Use the following pieces of context to answer the question at the end.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    # Create prompt
+    prompt = f"""You are an AI Twin that represents a person. Use ONLY the following context to answer the question.
+If you don't know the answer from the context, say "I don't have that information in my profile."
 
 Context:
 {context}
@@ -71,74 +95,65 @@ Context:
 Question: {question}
 
 Answer:"""
-    PROMPT = PromptTemplate(
-        template=prompt_template, input_variables=["context", "question"]
-    )
-
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-        chain_type_kwargs={"prompt": PROMPT}
-    )
-    return qa_chain
+
+    # Generate response using Hugging Face Inference API
+    try:
+        response = llm_client.text_generation(
+            prompt,
+            max_new_tokens=512,
+            temperature=0.3,
+            repetition_penalty=1.1
+        )
+        return response.strip()
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
 
 # --- Global Initialization ---
 print("Loading documents...")
-docs = load_documents()
-print(f"Loaded {len(docs)} documents.")
-
-print("Creating vector store...")
-vector_db = create_vector_store(docs)
-print("Vector store created.")
-
-print("Setting up RAG chain...")
-try:
-    rag_chain = setup_rag_chain(vector_db)
-    print("RAG chain setup complete.")
-except Exception as e:
-    print(f"Error setting up RAG chain (likely missing HF Token): {e}")
-    rag_chain = None
-
-def ask_ai_twin(question):
-    if not rag_chain:
-        return "Error: RAG Chain not initialized. Please check your Hugging Face Token."
-
-    result = rag_chain.invoke({"query": question})
-    return result["result"]
+docs, fnames = load_documents()
+print(f"Loaded {len(docs)} documents: {fnames}")
+
+print("Building vector store...")
+kb_collection, embedding_model = build_vector_store(docs, fnames)
+print("Vector store ready.")
+
+print("Initializing LLM client...")
+hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN", None)
+llm = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=hf_token)
+print("LLM client ready.")
 
 # --- Gradio UI ---
 def load_profile_summary():
     try:
-        with open(os.path.join(KNOWLEDGE_BASE_DIR, "profile.txt"), "r") as f:
+        with open(os.path.join(KNOWLEDGE_BASE_DIR, "profile.txt"), "r", encoding="utf-8") as f:
            return f.read()
     except FileNotFoundError:
         return "Profile not found."
 
-with gr.Blocks(title="My AI Twin") as demo:
-    gr.Markdown("# My AI Twin")
+def ask_ai_twin(message, chat_history):
+    answer = query_rag(message, kb_collection, embedding_model, llm)
+    chat_history.append((message, answer))
+    return "", chat_history
+
+with gr.Blocks(title="My AI Twin", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🤖 My AI Twin")
     gr.Markdown("Ask me anything about my professional background, skills, and projects!")
 
     with gr.Row():
         with gr.Column(scale=1):
-            gr.Markdown("### Profile Summary")
+            gr.Markdown("### 📋 Profile Summary")
             profile_content = load_profile_summary()
-            gr.Textbox(value=profile_content, label="About Me", interactive=False, lines=10)
+            gr.Textbox(value=profile_content, label="About Me", interactive=False, lines=15)
 
         with gr.Column(scale=2):
-            chatbot = gr.Chatbot(label="Conversation")
-            msg = gr.Textbox(label="Ask a question")
-            submit_btn = gr.Button("Submit")
-            clear = gr.Button("Clear")
-
-            def respond(message, chat_history):
-                bot_message = ask_ai_twin(message)
-                chat_history.append((message, bot_message))
-                return "", chat_history
-
-            msg.submit(respond, [msg, chatbot], [msg, chatbot])
-            submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
+            chatbot = gr.Chatbot(label="Conversation", height=400)
+            msg = gr.Textbox(label="Ask a question", placeholder="e.g. What are my skills?")
+            with gr.Row():
+                submit_btn = gr.Button("Submit", variant="primary")
+                clear = gr.Button("Clear")
+
+            msg.submit(ask_ai_twin, [msg, chatbot], [msg, chatbot])
+            submit_btn.click(ask_ai_twin, [msg, chatbot], [msg, chatbot])
     clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
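
Review note on the app.py rewrite: the LangChain stack is replaced with direct calls to sentence-transformers, ChromaDB, and huggingface_hub's InferenceClient. One behavioral change worth flagging: chromadb.Client() is an in-memory client, so the index is rebuilt from knowledge_base/ on every startup and the old chroma_db persistence directory is no longer used. The sketch below is a minimal smoke test of the new pipeline, not part of this commit; it assumes knowledge_base/ contains non-empty .txt files and that HUGGINGFACEHUB_API_TOKEN is set, so that app.py's module-level globals (kb_collection, embedding_model, llm) initialize on import.

# smoke_test.py — hypothetical helper, not included in this commit.
# Importing app triggers its module-level initialization (document loading,
# chunking, embedding, and ChromaDB indexing).
from app import chunk_text, query_rag, kb_collection, embedding_model, llm

# chunk_text's defaults give a stride of chunk_size - overlap = 400 chars,
# so a 1,200-character document chunks at offsets 0, 400, and 800.
assert len(chunk_text("x" * 1200)) == 3

# End-to-end retrieval + generation against the in-memory collection.
print(query_rag("What are my skills?", kb_collection, embedding_model, llm))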
requirements.txt CHANGED
@@ -1,8 +1,3 @@
-langchain
-langchain-community
-langchain-huggingface
-langchain-text-splitters
-langchain-core
 chromadb
 sentence-transformers
 gradio
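
Note on the trimmed requirements: app.py now imports huggingface_hub directly, but it is not listed here; it normally arrives as a transitive dependency of gradio and sentence-transformers, so the app should still resolve, though pinning it explicitly would be more robust. Below is a minimal import check, a sketch assuming these requirements are installed in a fresh environment.

# verify_deps.py — hypothetical helper, not included in this commit.
import importlib

# Every top-level package app.py imports, including the transitive one.
for name in ("gradio", "chromadb", "sentence_transformers", "huggingface_hub"):
    importlib.import_module(name)
print("All app.py imports resolve.")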