Frazer2810's picture
Update agent.py
06e4fd4 verified
raw
history blame
7.66 kB
"""LangGraph Agent with OpenAI"""
import os
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
# Tools definition
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    # The docstring above doubles as the tool description shown to the LLM,
    # so it is kept verbatim; the work itself is plain Python multiplication.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    # Docstring is the LLM-facing tool description; keep it unchanged.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    # Computes a - b (the first argument minus the second).
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.
    Args:
        a: first int
        b: second int
    """
    # True division (always a float); a zero divisor is rejected explicitly
    # with a readable message rather than Python's ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    # Mirror divide(): reject a zero divisor with an explicit ValueError and a
    # readable message instead of letting `%` raise ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot compute modulus by zero.")
    # Python's % takes the sign of the divisor (e.g. -7 % 3 == 2).
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query.
    """
    # Returns the formatted results as one string, a "No ... results" notice,
    # or an "Error searching Wikipedia: ..." string; it never raises, so the
    # model always gets text back from this tool.
    max_chars = 2000  # per-document cap on the text passed back to the model
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    except Exception as e:
        # Best-effort tool: report the failure to the model as text.
        return f"Error searching Wikipedia: {str(e)}"
    if not search_docs:
        return f"No Wikipedia results found for: {query}"
    sections = []
    for doc in search_docs:
        content = doc.page_content
        # Only mark the content as elided when it was actually cut short;
        # a trailing "..." on complete text misleads the model.
        if len(content) > max_chars:
            content = content[:max_chars] + "..."
        sections.append(
            f'Source: {doc.metadata.get("source", "Wikipedia")}\nContent: {content}'
        )
    return "\n\n---\n\n".join(sections)
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.
    Args:
        query: The search query.
    """
    # Returns the formatted results as one string, a "No ... results" notice,
    # or an "Error searching Arxiv: ..." string; it never raises, so the model
    # always gets text back from this tool.
    max_chars = 1500  # per-paper cap on the text passed back to the model
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    except Exception as e:
        # Best-effort tool: report the failure to the model as text.
        return f"Error searching Arxiv: {str(e)}"
    if not search_docs:
        return f"No Arxiv results found for: {query}"
    sections = []
    for doc in search_docs:
        content = doc.page_content
        # Append an ellipsis only when the text was actually truncated.
        if len(content) > max_chars:
            content = content[:max_chars] + "..."
        sections.append(
            f'Title: {doc.metadata.get("Title", "Unknown")}\n'
            f'Authors: {doc.metadata.get("Authors", "Unknown")}\n'
            f"Content: {content}"
        )
    return "\n\n---\n\n".join(sections)
# System prompt given to the model ahead of each conversation: it asks the
# model to show its reasoning and then close with a terse answer that follows
# the formatting rules below.
# NOTE(review): the template has no fixed marker (e.g. "FINAL ANSWER:") before
# the answer, and the agent returns the last AI message's full text, so any
# reasoning the model writes comes back together with the answer -- confirm
# this is what the evaluator expects.
system_prompt = (
    "You are a general AI assistant. I will ask you a question. "
    "Report your thoughts, and finish your answer with the following template: "
    "[YOUR FINAL ANSWER]. "
    "[YOUR FINAL ANSWER] should be a number OR as few words as possible OR a "
    "comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number neither "
    "use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified "
    "otherwise. "
    "If you are asked for a comma separated list, apply the above rules depending "
    "of whether the element to be put in the list is a number or a string."
)
# Every tool the model may call. The same list is bound to the LLM (so it can
# request calls) and handed to ToolNode (which executes them), keeping the
# advertised and executable tool sets identical.
tools = [multiply, add, subtract, divide, modulus, wiki_search, arxiv_search]
class LangGraphAgent:
    """LangGraph Agent with OpenAI that can be used in HuggingFace Space evaluation.

    The compiled graph alternates between an ``assistant`` node, which calls
    the tool-bound chat model, and a ``tools`` node, which runs the tools the
    model asked for. Instances are callable: ``agent(question)`` returns the
    text of the last AI message produced by the graph.
    """

    def __init__(self, model: str = "gpt-4-turbo", temperature: float = 0):
        """Initialize the agent with OpenAI LLM and tools.

        Args:
            model: OpenAI chat model name passed to ``ChatOpenAI``.
            temperature: Sampling temperature passed to ``ChatOpenAI``
                (default 0).

        Raises:
            ValueError: if neither ``OPENAI_KEY`` nor ``OPENAI_API_KEY`` is
                set (or both are empty).
        """
        print("Initializing LangGraphAgent...")
        self.model = model
        self.temperature = temperature
        # Prefer the project-specific variable, fall back to the standard one.
        self.api_key = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "OPENAI_KEY or OPENAI_API_KEY environment variable is required"
            )
        self.graph = self._build_graph()
        print("LangGraphAgent initialized successfully.")

    def _build_graph(self):
        """Build and compile the assistant/tools LangGraph workflow."""
        llm = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            api_key=self.api_key,
        )
        # The same module-level `tools` list is advertised to the model here
        # and executed by ToolNode below.
        llm_with_tools = llm.bind_tools(tools)
        sys_msg = SystemMessage(content=system_prompt)

        def assistant(state: MessagesState):
            """Call the tool-bound LLM on the conversation so far."""
            messages = state["messages"]
            # The system prompt is prepended per call (not stored in the graph
            # state), and only when the history does not already carry one.
            if not any(isinstance(msg, SystemMessage) for msg in messages):
                messages = [sys_msg] + messages
            response = llm_with_tools.invoke(messages)
            return {"messages": [response]}

        builder = StateGraph(MessagesState)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        # Per the LangGraph docs, tools_condition routes to the "tools" node
        # when the assistant requested tool calls and ends the run otherwise.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    @staticmethod
    def _content_to_text(content) -> str:
        """Return chat-message ``content`` as plain text.

        A string is returned unchanged. A list of content blocks is flattened
        by concatenating its string items and the ``"text"`` value of its dict
        items; items without text contribute nothing. Anything else is passed
        through ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict):
                    parts.append(str(block.get("text") or ""))
            return "".join(parts)
        return str(content)

    def __call__(self, question: str) -> str:
        """Process a question and return an answer.

        Args:
            question: The question to answer.

        Returns:
            str: The text of the last AI message in the graph's output; a
            fallback message if the graph produced no AI message; or
            ``"Error: ..."`` if running the graph raised.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")
        try:
            result = self.graph.invoke({"messages": [HumanMessage(content=question)]})
            ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
            if not ai_messages:
                return "I couldn't generate a response. Please try again."
            # Message content is not guaranteed to be a plain string; flatten
            # it so the declared `-> str` contract holds.
            answer = self._content_to_text(ai_messages[-1].content)
            print(f"Agent returning answer (first 100 chars): {answer[:100]}...")
            return answer
        except Exception as e:
            # Top-level boundary: report the failure as the answer text rather
            # than crash the caller's evaluation loop.
            print(f"Error processing question: {e}")
            return f"Error: {str(e)}"
# For backwards compatibility and testing: code that refers to ``BasicAgent``
# gets the LangGraph-based agent. This binds the same class object under a
# second name; it is not a subclass or a copy.
BasicAgent = LangGraphAgent
if __name__ == "__main__":
    import sys

    # Manual smoke test: runs a few questions end to end through the agent.
    print("Testing LangGraphAgent...")
    # Check the same variables LangGraphAgent.__init__ accepts, so a key that
    # is only exported as OPENAI_API_KEY does not abort the run.
    if not (os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")):
        print("Error: OPENAI_KEY (or OPENAI_API_KEY) environment variable not set")
        print("Please set it with: export OPENAI_KEY=your-openai-api-key")
        # sys.exit is always available; the bare `exit` builtin comes from the
        # site module and may be missing (e.g. under `python -S`).
        sys.exit(1)
    try:
        agent = LangGraphAgent()
        test_questions = [
            "What is 15 * 23?",
            "Search Wikipedia for information about quantum computing",
            "What are the latest developments in AI according to recent papers on Arxiv?",
        ]
        for question in test_questions:
            print(f"\nQuestion: {question}")
            answer = agent(question)
            print(f"Answer: {answer}")
    except Exception as e:
        print(f"Error during testing: {e}")