Frazer2810 committed
Commit 6589f98 · verified · 1 Parent(s): bc83a10

Update agent.py

Files changed (1)
agent.py +66 -86
agent.py CHANGED
@@ -1,21 +1,38 @@
- """LangGraph Agent – GPT-4.1 version / Hugging Face Spaces"""
+ """LangGraph Agent – GPT-4.1 / Hugging Face Spaces (lazy imports)"""
  import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
-
- # LLM providers
- from langchain_openai import ChatOpenAI  # NEW (GPT-4.1)
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import (
-     ChatHuggingFace,
-     HuggingFaceEndpoint,
-     HuggingFaceEmbeddings,
- )
+ from langgraph.prebuilt import tools_condition, ToolNode
+ from langchain_openai import ChatOpenAI
+
+ # --------------------------------------------------------------------------- #
+ # Optional imports (if the package is missing, the provider is disabled)      #
+ # --------------------------------------------------------------------------- #
+ def _lazy_import(name):
+     try:
+         module = __import__(name, fromlist=["*"])
+         return module
+     except ModuleNotFoundError:
+         return None

- # Tools & loaders
+ lg_google = _lazy_import("langchain_google_genai")
+ lg_groq = _lazy_import("langchain_groq")
+ lg_hf = _lazy_import("langchain_huggingface")
+
+ if lg_google:
+     ChatGoogleGenerativeAI = lg_google.ChatGoogleGenerativeAI
+ if lg_groq:
+     ChatGroq = lg_groq.ChatGroq
+ if lg_hf:
+     ChatHuggingFace = lg_hf.ChatHuggingFace
+     HuggingFaceEndpoint = lg_hf.HuggingFaceEndpoint
+     HuggingFaceEmbeddings = lg_hf.HuggingFaceEmbeddings
+ else:
+     from langchain_huggingface import HuggingFaceEmbeddings  # embeddings only
+
+ # --------------------------------------------------------------------------- #
+ # Tools & loaders                                                             #
+ # --------------------------------------------------------------------------- #
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_community.vectorstores import SupabaseVectorStore
@@ -24,38 +41,26 @@ from langchain_core.tools import tool
  from langchain.tools.retriever import create_retriever_tool
  from supabase.client import Client, create_client

- # --------------------------------------------------------------------------- #
- # Load environment variables (optional .env + HF Spaces secrets) #
- # --------------------------------------------------------------------------- #
- load_dotenv()  # in Spaces the secrets are already in os.environ
+ load_dotenv()  # HF Spaces secrets

- # --------------------------------------------------------------------------- #
- # Example tools (arithmetic) #
- # --------------------------------------------------------------------------- #
+ # -------------------- Example tools -------------------- #
  @tool
  def multiply(a: int, b: int) -> int: return a * b
-
  @tool
  def add(a: int, b: int) -> int: return a + b
-
  @tool
  def subtract(a: int, b: int) -> int: return a - b
-
  @tool
  def divide(a: int, b: int) -> float:
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b
-
  @tool
  def modulus(a: int, b: int) -> int: return a % b

- # --------------------------------------------------------------------------- #
- # TOOL: Wikipedia #
- # --------------------------------------------------------------------------- #
+ # -------------------- Wikipedia -------------------------- #
  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia (max 2 docs) and return formatted result."""
      docs = WikipediaLoader(query=query, load_max_docs=2).load()
      return "\n\n---\n\n".join(
          f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -63,12 +68,9 @@ def wiki_search(query: str) -> str:
          for d in docs
      )

- # --------------------------------------------------------------------------- #
- # TOOL: Tavily web search #
- # --------------------------------------------------------------------------- #
+ # -------------------- Tavily ----------------------------- #
  @tool
  def web_search(query: str) -> str:
-     """Search Tavily (max 3 docs) and return formatted result."""
      docs = TavilySearchResults(max_results=3).invoke(query=query)
      return "\n\n---\n\n".join(
          f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -76,12 +78,9 @@ def web_search(query: str) -> str:
          for d in docs
      )

- # --------------------------------------------------------------------------- #
- # TOOL: ArXiv #
- # --------------------------------------------------------------------------- #
+ # -------------------- ArXiv ------------------------------ #
  @tool
  def arxiv_search(query: str) -> str:
-     """Search ArXiv (max 3 docs) and return formatted snippet."""
      docs = ArxivLoader(query=query, load_max_docs=3).load()
      return "\n\n---\n\n".join(
          f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
@@ -90,14 +89,14 @@ def arxiv_search(query: str) -> str:
      )

  # --------------------------------------------------------------------------- #
- # System prompt #
+ # System prompt #
  # --------------------------------------------------------------------------- #
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
  sys_msg = SystemMessage(content=system_prompt)

  # --------------------------------------------------------------------------- #
- # Vector store for the retriever #
+ # Vector store / retriever #
  # --------------------------------------------------------------------------- #
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
  supabase: Client = create_client(
@@ -117,49 +116,38 @@ question_search_tool = create_retriever_tool(
  )

  # --------------------------------------------------------------------------- #
- # Tool list registration #
+ # Tool list #
  # --------------------------------------------------------------------------- #
  tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arxiv_search,
+     multiply, add, subtract, divide, modulus,
+     wiki_search, web_search, arxiv_search,
      question_search_tool,
  ]

  # --------------------------------------------------------------------------- #
- # LangGraph graph construction #
+ # Graph construction #
  # --------------------------------------------------------------------------- #
  def build_graph(provider: str = "openai"):
-     """Return a ready-to-use LangGraph graph.
-
-     provider: "openai" (default), "google", "groq", "huggingface"
-     """
-     # --- LLM selection ------------------------------------------------------ #
+     # ------------------- LLM selection ------------------------------------- #
      if provider == "openai":
-         openai_key = os.getenv("OPENAI_KEY")
-         if not openai_key:
-             raise ValueError(
-                 "❌ Environment variable OPENAI_KEY is missing. "
-                 "Add the secret from the 'Secrets' tab of the Space."
-             )
-         llm = ChatOpenAI(
-             model_name="gpt-4.1",
-             temperature=0,
-             openai_api_key=openai_key,
-         )
+         key = os.getenv("OPENAI_KEY")
+         if not key:
+             raise ValueError("OPENAI_KEY missing: add the secret in the Space.")
+         llm = ChatOpenAI(model_name="gpt-4.1", temperature=0, openai_api_key=key)

      elif provider == "google":
+         if not lg_google:
+             raise ImportError("langchain_google_genai is not installed.")
          llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)

      elif provider == "groq":
+         if not lg_groq:
+             raise ImportError("langchain_groq is not installed.")
          llm = ChatGroq(model="qwen-qwq-32b", temperature=0)

      elif provider == "huggingface":
+         if not lg_hf:
+             raise ImportError("langchain_huggingface is not installed.")
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
                  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -167,32 +155,25 @@ def build_graph(provider: str = "openai"):
              )
          )
      else:
-         raise ValueError(
-             "Invalid provider. Choose 'openai', 'google', 'groq' or 'huggingface'."
-         )
+         raise ValueError("Invalid provider.")

-     # Enable tool calling
      llm_with_tools = llm.bind_tools(tools)

-     # ------------------------- NODES --------------------------------------- #
+     # ------------------- Nodes -------------------------------------------- #
      def assistant(state: MessagesState):
-         """Invoke the model."""
          return {"messages": [llm_with_tools.invoke(state["messages"])]}

      def retriever(state: MessagesState):
-         """Add a similar Q/A pair to the history as an example."""
          similar = vector_store.similarity_search(state["messages"][0].content)
          if similar:
-             example_msg = HumanMessage(
-                 content=(
-                     "Here I provide a similar question and answer for reference:\n\n"
-                     f"{similar[0].page_content}"
-                 )
+             example = HumanMessage(
+                 content=("Here I provide a similar question and answer for reference:\n\n"
+                          f"{similar[0].page_content}")
              )
-             return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+             return {"messages": [sys_msg] + state["messages"] + [example]}
          return {"messages": [sys_msg] + state["messages"]}

-     # --------------------------- GRAPH ------------------------------------- #
+     # ------------------- Graph -------------------------------------------- #
      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
@@ -205,14 +186,13 @@ def build_graph(provider: str = "openai"):

      return builder.compile()

-
  # --------------------------------------------------------------------------- #
- # Quick test (python agent.py) #
+ # Quick test #
  # --------------------------------------------------------------------------- #
  if __name__ == "__main__":
-     graph = build_graph(provider="openai")
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     msgs = [HumanMessage(content=question)]
-     result = graph.invoke({"messages": msgs})
-     for m in result["messages"]:
+     g = build_graph()
+     q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     msgs = [HumanMessage(content=q)]
+     res = g.invoke({"messages": msgs})
+     for m in res["messages"]:
          m.pretty_print()
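
The main change in this commit is that the optional provider packages are no longer imported unconditionally: _lazy_import swallows a missing module, and build_graph raises an ImportError only when the corresponding provider is actually requested. The sketch below illustrates that pattern in isolation; pick_llm and the __main__ demo are hypothetical names introduced here for illustration only and are not part of the commit.

# Minimal sketch of the lazy-import guard used in agent.py (illustrative only).
def _lazy_import(name):
    """Return the module if it is installed, otherwise None."""
    try:
        return __import__(name, fromlist=["*"])
    except ModuleNotFoundError:
        return None

lg_google = _lazy_import("langchain_google_genai")
lg_groq = _lazy_import("langchain_groq")

def pick_llm(provider: str):
    """Hypothetical stand-in for the provider branch inside build_graph."""
    if provider == "google":
        if not lg_google:
            raise ImportError("langchain_google_genai is not installed.")
        return lg_google.ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        if not lg_groq:
            raise ImportError("langchain_groq is not installed.")
        return lg_groq.ChatGroq(model="qwen-qwq-32b", temperature=0)
    raise ValueError("Invalid provider.")

if __name__ == "__main__":
    try:
        pick_llm("groq")
    except ImportError as err:
        # Without the optional dependency, the error surfaces here at call time
        # instead of crashing the whole module at import time.
        print(err)

The point of the guard is that a Space only needs the packages for the provider it actually uses; the module still imports cleanly when the other provider packages are absent.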