Anton Novoselov committed on
Commit
9563789
·
1 Parent(s): 6048c1e

add agent

Browse files
Files changed (5) hide show
  1. .gitignore +115 -0
  2. agent.py +265 -0
  3. app.py +18 -5
  4. requirements.txt +17 -2
  5. system_prompt.txt +17 -0
.gitignore ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+
7
+ # Distribution / packaging
8
+ .Python
9
+ build/
10
+ develop-eggs/
11
+ dist/
12
+ downloads/
13
+ eggs/
14
+ .eggs/
15
+ lib/
16
+ lib64/
17
+ parts/
18
+ sdist/
19
+ var/
20
+ wheels/
21
+ *.egg-info/
22
+ .installed.cfg
23
+ *.egg
24
+
25
+ # Virtual environments
26
+ venv/
27
+ ENV/
28
+ env/
29
+ .env
30
+ .venv
31
+ env.bak/
32
+ venv.bak/
33
+ .python-version
34
+
35
+ # Unit test / coverage reports
36
+ htmlcov/
37
+ .tox/
38
+ .nox/
39
+ .coverage
40
+ .coverage.*
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+ *.cover
45
+ .hypothesis/
46
+ .pytest_cache/
47
+ pytest-*.xml
48
+
49
+ # Jupyter Notebook
50
+ .ipynb_checkpoints
51
+
52
+ # IPython
53
+ profile_default/
54
+ ipython_config.py
55
+
56
+ # Logs
57
+ *.log
58
+ logs/
59
+ log/
60
+
61
+ # IDE specific files
62
+ .idea/
63
+ .vscode/
64
+ *.swp
65
+ *.swo
66
+ *~
67
+ .DS_Store
68
+ .project
69
+ .pydevproject
70
+ .settings/
71
+ .vs/
72
+ *.sublime-project
73
+ *.sublime-workspace
74
+
75
+ # Database
76
+ *.db
77
+ *.rdb
78
+ *.sqlite
79
+ *.sqlite3
80
+
81
+ # Environment variables
82
+ .env
83
+ .env.local
84
+ .env.development.local
85
+ .env.test.local
86
+ .env.production.local
87
+
88
+ # macOS specific
89
+ .DS_Store
90
+ .AppleDouble
91
+ .LSOverride
92
+ Icon
93
+ ._*
94
+ .DocumentRevisions-V100
95
+ .fseventsd
96
+ .Spotlight-V100
97
+ .TemporaryItems
98
+ .Trashes
99
+ .VolumeIcon.icns
100
+ .com.apple.timemachine.donotpresent
101
+
102
+ # AI/model files
103
+ *.h5
104
+ *.pb
105
+ *.onnx
106
+ *.tflite
107
+ *.pt
108
+ *.pth
109
+ *.weights
110
+
111
+ # Temporary files
112
+ tmp/
113
+ temp/
114
+ .tmp
115
+ *.tmp
agent.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_core.tools import tool
15
+ from langchain.tools.retriever import create_retriever_tool
16
+ from langchain_community.vectorstores import Chroma
17
+ from langchain_core.documents import Document
18
+ import shutil
19
+ import pandas as pd
20
+ import json
21
+
22
+ load_dotenv()
23
+
24
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    return a * b
32
+
33
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
42
+
43
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
52
+
53
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division always yields a float, so the return annotation is
    # float (the original said int, which was inaccurate).
    return a / b
64
+
65
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
74
+
75
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        The matched pages formatted as <Document .../> blocks, separated
        by "---" dividers.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # Return the plain string so the result matches the declared -> str
    # annotation; the original returned {"wiki_results": ...}, a dict the
    # tool runtime would have had to stringify anyway.
    return formatted_search_docs
88
+
89
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        The search hits formatted as <Document .../> blocks, separated by
        "---" dividers.
    """
    # BaseTool.invoke takes the input positionally; the original
    # `invoke(query=query)` raised TypeError. TavilySearchResults also
    # returns a list of plain dicts (keys "url"/"content"), not Document
    # objects, so the original `doc.metadata[...]` access could never work.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}"/>\n{result.get("content", "")}\n</Document>'
        for result in search_results
    )
    # Return a string to match the declared -> str annotation (the
    # original wrapped the result in a dict).
    return formatted_search_docs
102
+
103
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        The matched papers (first 1000 chars each) formatted as
        <Document .../> blocks, separated by "---" dividers.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader document metadata exposes keys like "Published",
    # "Title" and "Authors" — "source" is not guaranteed, so the original
    # doc.metadata["source"] raised KeyError. Fall back to the paper title.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    # Return a string to match the declared -> str annotation (the
    # original wrapped the result in a dict).
    return formatted_search_docs
116
+
117
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message prepended to every conversation (FINAL ANSWER format rules).
sys_msg = SystemMessage(content=system_prompt)

# --- Start ChromaDB Setup ---
# Define the directory for ChromaDB persistence
CHROMA_DB_DIR = "./chroma_db"
CSV_FILE_PATH = "./supabase_docs.csv"  # Path to your CSV file

# Build embeddings (this remains the same)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768

# Initialize ChromaDB.
# If the directory exists and contains data, load the existing vector store.
# Otherwise, create a new one and add documents from the CSV file.
if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
    print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
    vector_store = Chroma(
        persist_directory=CHROMA_DB_DIR,
        embedding_function=embeddings
    )
else:
    print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and loading documents from {CSV_FILE_PATH}.")
    # Ensure the directory is clean before creating new
    if os.path.exists(CHROMA_DB_DIR):
        shutil.rmtree(CHROMA_DB_DIR)
    os.makedirs(CHROMA_DB_DIR)

    # Load data from the CSV file
    if not os.path.exists(CSV_FILE_PATH):
        raise FileNotFoundError(f"CSV file not found at {CSV_FILE_PATH}. Please ensure it's in the root directory.")

    df = pd.read_csv(CSV_FILE_PATH)
    documents = []
    for index, row in df.iterrows():
        content = row["content"]

        # Extract the question part from the content
        # (assumes the question is everything before "Final answer :")
        question_part = content.split("Final answer :")[0].strip()

        # Extract the final answer part from the content
        final_answer_part = content.split("Final answer :")[-1].strip() if "Final answer :" in content else ""

        # Parse the metadata string into a dictionary.
        # NOTE(review): replacing single quotes with double quotes is a
        # fragile way to turn a Python-dict repr into valid JSON — it
        # corrupts values that legitimately contain apostrophes/quotes.
        # Any parse failure silently falls back to empty metadata below.
        try:
            metadata = json.loads(row["metadata"].replace("'", "\""))
        except json.JSONDecodeError:
            metadata = {}  # Fallback if parsing fails

        # Add the extracted final answer to the metadata for easy retrieval
        metadata["final_answer"] = final_answer_part

        # Create a Document object. The page_content is the question (used
        # for similarity search); the answer travels in metadata.
        documents.append(Document(page_content=question_part, metadata=metadata))

    if not documents:
        print("No documents loaded from CSV. ChromaDB will be empty.")
        # Create an empty ChromaDB if no documents are found
        vector_store = Chroma(
            persist_directory=CHROMA_DB_DIR,
            embedding_function=embeddings
        )
    else:
        vector_store = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            persist_directory=CHROMA_DB_DIR
        )
        # NOTE(review): Chroma.persist() is deprecated in recent
        # langchain-chroma releases (persistence is automatic when
        # persist_directory is set) — confirm against the pinned version.
        vector_store.persist()  # Save the new vector store to disk
        print(f"ChromaDB initialized and persisted with {len(documents)} documents from CSV.")

# Create retriever tool using the Chroma vector store.
# The retrieved document's metadata carries the 'final_answer'.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question_Search",
    description="A tool to retrieve similar questions from a vector store. The retrieved document's metadata contains the 'final_answer' to the question.",
)

# All tools exposed to the agent graph (math, search, and retrieval).
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    retriever_tool,
]
213
+
214
+ # Build graph function
215
def build_graph(provider: str = "google"):
    """Build and compile the agent graph.

    Args:
        provider: LLM backend to use — 'google', 'groq' or 'huggingface'.

    Raises:
        ValueError: If provider is not one of the supported names.
    """
    # NOTE(review): the LLM selected here is never wired into the compiled
    # graph below — only the 'retriever' node is added — so everything up
    # to and including `assistant` is currently dead code. Confirm whether
    # the assistant/tools flow was meant to be attached to the graph.
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind the tool list so the model can emit tool calls (unused — see note above).
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-bound LLM on the message history."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    from langchain_core.messages import AIMessage

    def retriever(state: MessagesState):
        # Answer directly from the vector store: look up the single most
        # similar stored question and return its recorded answer.
        query = state["messages"][-1].content
        # Use the vector_store directly for similarity search to get the full Document object
        similar_docs = vector_store.similarity_search(query, k=1)

        if similar_docs:
            similar_doc = similar_docs[0]
            # Prioritize 'final_answer' from metadata, then check page_content
            if "final_answer" in similar_doc.metadata and similar_doc.metadata["final_answer"]:
                answer = similar_doc.metadata["final_answer"]
            elif "Final answer :" in similar_doc.page_content:
                answer = similar_doc.page_content.split("Final answer :")[-1].strip()
            else:
                answer = similar_doc.page_content.strip()  # Fallback to page_content if no explicit answer

            # The system prompt expects "FINAL ANSWER: [ANSWER]".
            # Return the extracted answer directly; the prompt handles the formatting.
            return {"messages": [AIMessage(content=answer)]}
        else:
            return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}

    # Single-node graph: 'retriever' is both the entry and the finish point.
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")

    return builder.compile()
app.py CHANGED
@@ -1,8 +1,13 @@
 
1
  import os
 
2
  import gradio as gr
3
  import requests
4
- import inspect
5
  import pandas as pd
 
 
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -10,14 +15,22 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
 
 
13
  class BasicAgent:
 
14
  def __init__(self):
15
  print("BasicAgent initialized.")
 
 
16
  def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
@@ -193,4 +206,4 @@ if __name__ == "__main__":
193
  print("-"*(60 + len(" App Starting ")) + "\n")
194
 
195
  print("Launching Gradio Interface for Basic Agent Evaluation...")
196
- demo.launch(debug=True, share=False)
 
1
+ """ Basic Agent Evaluation Runner"""
2
  import os
3
+ import inspect
4
  import gradio as gr
5
  import requests
 
6
  import pandas as pd
7
+ from langchain_core.messages import HumanMessage
8
+ from agent import build_graph
9
+
10
+
11
 
12
  # (Keep Constants as is)
13
  # --- Constants ---
 
15
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
+
19
+
20
class BasicAgent:
    """A langgraph agent that answers questions via the compiled graph."""
    def __init__(self):
        print("BasicAgent initialized.")
        # Compile the LangGraph pipeline once; reused for every question.
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run the graph on a single question and return the answer text."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        messages = [HumanMessage(content=question)]
        result = self.graph.invoke({"messages": messages})
        # The last message in the resulting state holds the agent's answer.
        answer = result['messages'][-1].content
        return answer  # no [14:] slicing needed any more
32
+
33
+
34
 
35
  def run_and_submit_all( profile: gr.OAuthProfile | None):
36
  """
 
206
  print("-"*(60 + len(" App Starting ")) + "\n")
207
 
208
  print("Launching Gradio Interface for Basic Agent Evaluation...")
209
+ demo.launch(debug=True, share=False)
requirements.txt CHANGED
@@ -1,4 +1,19 @@
1
  gradio
2
  requests
3
-
4
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
  requests
3
+ langchain
4
+ langchain-community
5
+ langchain-core
6
+ langchain-google-genai
7
+ langchain-huggingface
8
+ langchain-groq
9
+ langchain-tavily
10
+ langchain-chroma
11
+ langgraph
12
+ huggingface_hub
13
+ arxiv
14
+ pymupdf
15
+ wikipedia
16
+ python-dotenv
17
+ sentence-transformers # required by HuggingFaceEmbeddings in agent.py
18
+ pandas
19
+ protobuf==3.20.3
system_prompt.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+
3
+ Your final answer must strictly follow this format:
4
+ FINAL ANSWER: [ANSWER]
5
+
6
+ Only write the answer in that exact format. Do not explain anything. Do not include any other text.
7
+
8
+ If you are provided with a similar question and its final answer, and the current question is **exactly the same**, then simply return the same final answer without using any tools.
9
+
10
+ Only use tools if the current question is different from the similar one.
11
+
12
+ Examples:
13
+ - FINAL ANSWER: FunkMonk
14
+ - FINAL ANSWER: Paris
15
+ - FINAL ANSWER: 128
16
+
17
+ If you do not follow this format exactly, your response will be considered incorrect.