# Frazer2810's picture
# Update agent.py
# 759a928 verified
# raw
# history blame
# 6.72 kB
"""LangGraph agent – retries after 5s, 30s and 60s; without Supabase."""
import os
import time
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# LLM providers
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# Tools & loaders
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
load_dotenv()
# --------------------------------------------------------------------------- #
# TOOLS #
# --------------------------------------------------------------------------- #
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    # NOTE(review): @tool likely exposes the docstring above as the tool
    # description, so its wording is intentionally left untouched.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    # Docstring wording kept as-is; only the body is restated.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    # Operand order matters: the result is a minus b, never b minus a.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b and return the quotient (error if b == 0)."""
    # Happy path first; a zero divisor falls through to the explicit error
    # so callers get a ValueError with a readable message.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of a by b."""
    # Unlike divide, there is no explicit zero check here: b == 0 surfaces
    # as Python's own ZeroDivisionError.
    remainder = a % b
    return remainder
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia (max 2 docs) and return formatted content."""
    pages = WikipediaLoader(query=query, load_max_docs=2).load()

    # Wrap each page in a <Document> envelope carrying its source and page.
    rendered = []
    for page in pages:
        source = page.metadata["source"]
        page_ref = page.metadata.get("page", "")
        rendered.append(
            f'<Document source="{source}" page="{page_ref}"/>\n'
            f"{page.page_content}\n</Document>"
        )

    # Documents are separated by a horizontal-rule style divider.
    return "\n\n---\n\n".join(rendered)
@tool
def web_search(query: str) -> str:
    """Perform a web search with Tavily (max 3 docs) and return formatted content."""
    # invoke() takes the tool input as its first positional argument; the
    # previous `invoke(query=query)` never supplied it and raised TypeError.
    results = TavilySearchResults(max_results=3).invoke(query)

    if isinstance(results, str):
        # Not a list of hits (e.g. an error message returned by the tool):
        # hand the text back unchanged instead of iterating over characters.
        return results

    blocks = []
    for hit in results:
        # Tavily hits are plain dicts (keys such as "url" and "content"), not
        # Document objects, so there is no .metadata / .page_content to read.
        source = hit.get("url", "")
        content = hit.get("content", "")
        # Web results have no page number; the attribute is kept (empty) so
        # the envelope matches the one produced by the other search tools.
        blocks.append(
            f'<Document source="{source}" page=""/>\n'
            f"{content}\n</Document>"
        )
    return "\n\n---\n\n".join(blocks)
@tool
def arxiv_search(query: str) -> str:
    """Search ArXiv (max 3 docs) and return first 1000 characters per paper."""
    papers = ArxivLoader(query=query, load_max_docs=3).load()

    blocks = []
    for paper in papers:
        meta = paper.metadata
        # ArxivLoader's default metadata (Published/Title/Authors/Summary) has
        # no "source" key, so indexing it directly raised KeyError. Use it
        # when present, otherwise fall back to the entry id or the title.
        if "source" in meta:
            source = meta["source"]
        else:
            source = meta.get("entry_id") or meta.get("Title", "")
        page_ref = meta.get("page", "")
        # Only the first 1000 characters of each paper are returned to keep
        # the tool output small.
        excerpt = paper.page_content[:1000]
        blocks.append(
            f'<Document source="{source}" page="{page_ref}"/>\n'
            f"{excerpt}\n</Document>"
        )
    return "\n\n---\n\n".join(blocks)
# --------------------------------------------------------------------------- #
# System prompt #
# --------------------------------------------------------------------------- #
# Read the system prompt once at import time. The path is relative, so it is
# resolved against the current working directory.
with open("system_prompt.txt", encoding="utf-8") as f:
    system_prompt = f.read()

# Prepended to the conversation by the assistant node inside build_graph.
sys_msg = SystemMessage(content=system_prompt)

# Every tool offered to the LLM: arithmetic helpers first, then the searches.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
]
# --------------------------------------------------------------------------- #
# Retry parameters #
# --------------------------------------------------------------------------- #
# Seconds to wait before each attempt. The leading 0 makes the first call
# immediate, which gives 4 attempts in total.
RETRY_DELAYS = [
    0,
    5,
    30,
    60,
]
MAX_ATTEMPTS = len(RETRY_DELAYS)
# --------------------------------------------------------------------------- #
# Build LangGraph #
# --------------------------------------------------------------------------- #
def build_graph(provider: str = "groq"):
    """Return a LangGraph graph with custom retry logic.

    Args:
        provider: LLM backend to use: "google", "groq" or "huggingface".

    Returns:
        The compiled graph, made of an "assistant" node (tool-bound LLM call
        wrapped in the retry helper) and a "tools" node whose output is fed
        back to the assistant.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    # ----------- LLM selection -------------------------------------------- #
    if provider == "google":
        llm_selected = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm_selected = ChatGroq(
            model="qwen-qwq-32b",
            temperature=0,
            max_retries=0,  # retries are handled by invoke_with_retry below
        )
    elif provider == "huggingface":
        llm_selected = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # The field is `endpoint_url`; the previous `url=` keyword is
                # not a HuggingFaceEndpoint field, so no endpoint was set.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    llm_with_tools = llm_selected.bind_tools(tools)

    # ---------------- Retry wrapper -------------------------------------- #
    def invoke_with_retry(messages):
        """Call the LLM, retrying on 503 errors after each RETRY_DELAYS pause."""
        last_err = None
        for attempt, delay in enumerate(RETRY_DELAYS):
            if delay:
                print(f"[Retry {attempt}/{MAX_ATTEMPTS-1}] waiting {delay}s")
                time.sleep(delay)
            try:
                return llm_with_tools.invoke(messages)
            except Exception as e:
                err_text = str(e)
                # Only "service unavailable" failures are retried, and only
                # while attempts remain; they are recognised by message text.
                if ("503" in err_text or "Service Unavailable" in err_text) and attempt < MAX_ATTEMPTS - 1:
                    last_err = e
                    continue  # retry
                raise  # any other error, or attempts exhausted
        # Only reachable when RETRY_DELAYS is empty (the loop never ran).
        raise last_err or RuntimeError("Unknown error during LLM invocation")

    # ---------------- Nodes ---------------------------------------------- #
    def assistant(state: MessagesState):
        """Run the LLM on the system prompt followed by the conversation so far."""
        messages = [sys_msg] + state["messages"]
        return {"messages": [invoke_with_retry(messages)]}

    # ---------------- Graph ---------------------------------------------- #
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # tools_condition picks the next step after each assistant turn.
    builder.add_conditional_edges("assistant", tools_condition)
    # Tool output always goes back to the assistant for another turn.
    builder.add_edge("tools", "assistant")
    return builder.compile()
# --------------------------------------------------------------------------- #
# Stand-alone test #
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Manual smoke test: run one question through the Groq-backed graph and
    # print every message of the resulting conversation.
    question = (
        "When was a picture of St. Thomas Aquinas first added "
        "to the Wikipedia page on the Principle of double effect?"
    )
    graph = build_graph(provider="groq")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()