"""LangGraph Agent with OpenAI""" import os from langgraph.graph import START, StateGraph, MessagesState from langgraph.prebuilt import tools_condition, ToolNode from langchain_openai import ChatOpenAI from langchain_community.document_loaders import WikipediaLoader from langchain_community.document_loaders import ArxivLoader from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain_core.tools import tool # Tools definition @tool def multiply(a: int, b: int) -> int: """Multiply two numbers. Args: a: first int b: second int """ return a * b @tool def add(a: int, b: int) -> int: """Add two numbers. Args: a: first int b: second int """ return a + b @tool def subtract(a: int, b: int) -> int: """Subtract two numbers. Args: a: first int b: second int """ return a - b @tool def divide(a: int, b: int) -> float: """Divide two numbers. Args: a: first int b: second int """ if b == 0: raise ValueError("Cannot divide by zero.") return a / b @tool def modulus(a: int, b: int) -> int: """Get the modulus of two numbers. Args: a: first int b: second int """ return a % b @tool def wiki_search(query: str) -> str: """Search Wikipedia for a query and return maximum 2 results. Args: query: The search query. """ try: search_docs = WikipediaLoader(query=query, load_max_docs=2).load() if not search_docs: return f"No Wikipedia results found for: {query}" formatted_search_docs = "\n\n---\n\n".join( [ f'Source: {doc.metadata.get("source", "Wikipedia")}\nContent: {doc.page_content[:2000]}...' for doc in search_docs ]) return formatted_search_docs except Exception as e: return f"Error searching Wikipedia: {str(e)}" @tool def arxiv_search(query: str) -> str: """Search Arxiv for a query and return maximum 3 results. Args: query: The search query. """ try: search_docs = ArxivLoader(query=query, load_max_docs=3).load() if not search_docs: return f"No Arxiv results found for: {query}" formatted_search_docs = "\n\n---\n\n".join( [ f'Title: {doc.metadata.get("Title", "Unknown")}\nAuthors: {doc.metadata.get("Authors", "Unknown")}\nContent: {doc.page_content[:1500]}...' for doc in search_docs ]) return formatted_search_docs except Exception as e: return f"Error searching Arxiv: {str(e)}" # System prompt system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: [YOUR FINAL ANSWER]. [YOUR FINAL ANSWER] should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.""" # Tools list tools = [ multiply, add, subtract, divide, modulus, wiki_search, arxiv_search, ] class LangGraphAgent: """LangGraph Agent with OpenAI that can be used in HuggingFace Space evaluation""" def __init__(self): """Initialize the agent with OpenAI LLM and tools""" print("Initializing LangGraphAgent...") # Get API key from environment self.api_key = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY") if not self.api_key: raise ValueError("OPENAI_KEY environment variable is required") # Initialize the graph self.graph = self._build_graph() print("LangGraphAgent initialized successfully.") def _build_graph(self): """Build the LangGraph workflow""" # Initialize OpenAI LLM llm = ChatOpenAI( model="gpt-4-turbo", # Changed from gpt-4-turbo-preview temperature=0, api_key=self.api_key ) # Bind tools to LLM llm_with_tools = llm.bind_tools(tools) # System message sys_msg = SystemMessage(content=system_prompt) # Node functions def assistant(state: MessagesState): """Assistant node""" # Ensure system message is included messages = state["messages"] if not any(isinstance(msg, SystemMessage) for msg in messages): messages = [sys_msg] + messages response = llm_with_tools.invoke(messages) return {"messages": [response]} # Build the graph builder = StateGraph(MessagesState) # Add nodes builder.add_node("assistant", assistant) builder.add_node("tools", ToolNode(tools)) # Add edges builder.add_edge(START, "assistant") builder.add_conditional_edges( "assistant", tools_condition, ) builder.add_edge("tools", "assistant") # Compile and return return builder.compile() def __call__(self, question: str) -> str: """ Process a question and return an answer. Args: question: The question to answer Returns: str: The answer to the question """ print(f"Agent received question (first 100 chars): {question[:100]}...") try: # Create message messages = [HumanMessage(content=question)] # Invoke the graph result = self.graph.invoke({"messages": messages}) # Extract the final answer ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)] if ai_messages: answer = ai_messages[-1].content print(f"Agent returning answer (first 100 chars): {answer[:100]}...") return answer else: return "I couldn't generate a response. Please try again." except Exception as e: print(f"Error processing question: {e}") return f"Error: {str(e)}" # For backwards compatibility and testing BasicAgent = LangGraphAgent if __name__ == "__main__": # Test the agent print("Testing LangGraphAgent...") if not os.environ.get("OPENAI_KEY"): print("Error: OPENAI_KEY environment variable not set") print("Please set it with: export OPENAI_KEY=your-openai-api-key") exit(1) try: agent = LangGraphAgent() test_questions = [ "What is 15 * 23?", "Search Wikipedia for information about quantum computing", "What are the latest developments in AI according to recent papers on Arxiv?", ] for question in test_questions: print(f"\nQuestion: {question}") answer = agent(question) print(f"Answer: {answer}") except Exception as e: print(f"Error during testing: {e}")