| """LangGraph Agent with OpenAI""" |
import os
import re
import sys

from langchain_community.document_loaders import ArxivLoader
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
|
|
| |
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    # The @tool decorator turns the docstring above into the tool description
    # the model sees, so only the body is implementation detail.
    product = a * b
    return product
|
|
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    # The docstring above doubles as the tool description shown to the model.
    total = a + b
    return total
|
|
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    # Order matters: the result is a minus b.
    difference = a - b
    return difference
|
|
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    # True division: the quotient is a float even for two int operands.
    if b != 0:
        return a / b
    # A zero divisor is reported with an explicit ValueError message.
    raise ValueError("Cannot divide by zero.")
|
|
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    # Python's % takes the sign of the divisor; b == 0 raises ZeroDivisionError.
    remainder = a % b
    return remainder
|
|
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.
    """
    # Per-document character limit applied to the returned page content.
    max_chars = 2000
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            return f"No Wikipedia results found for: {query}"

        entries = []
        for doc in search_docs:
            content = doc.page_content
            # Only signal truncation when text was actually cut off; an
            # unconditional "..." would tell the model content is missing.
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            source = doc.metadata.get("source", "Wikipedia")
            entries.append(f"Source: {source}\nContent: {content}")
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        # Failures are returned as text so the model receives them as the
        # tool result instead of the graph run aborting.
        return f"Error searching Wikipedia: {e}"
|
|
|
|
|
|
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    # Per-document character limit applied to the returned paper content.
    max_chars = 1500
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            return f"No Arxiv results found for: {query}"

        entries = []
        for doc in search_docs:
            content = doc.page_content
            # Only signal truncation when text was actually cut off.
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            title = doc.metadata.get("Title", "Unknown")
            authors = doc.metadata.get("Authors", "Unknown")
            entries.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        # Failures are returned as text so the model receives them as the
        # tool result instead of the graph run aborting.
        return f"Error searching Arxiv: {e}"
|
|
| |
# GAIA-style instructions. The explicit "FINAL ANSWER:" marker in the template
# gives a fixed token to locate the answer after the model's reported thoughts;
# without it the template "[YOUR FINAL ANSWER]" cannot be told apart from them.
system_prompt = (
    "You are a general AI assistant. I will ask you a question. Report your "
    "thoughts, and finish your answer with the following template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER]. "
    "[YOUR FINAL ANSWER] should be a number OR as few words as possible OR a "
    "comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified "
    "otherwise. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a "
    "string."
)
|
|
| |
# Tools bound to the model (bind_tools) and executed by the graph's ToolNode.
tools = [
    # Integer arithmetic.
    multiply, add, subtract, divide, modulus,
    # Retrieval from external sources.
    wiki_search, arxiv_search,
]
|
|
class LangGraphAgent:
    """LangGraph Agent with OpenAI that can be used in HuggingFace Space evaluation"""

    # Marker preceding the answer in the model's reply; matched
    # case-insensitively and tolerant of extra whitespace ("Final  answer :").
    _FINAL_ANSWER_MARKER = re.compile(r"final\s+answer\s*:", re.IGNORECASE)

    def __init__(self, model: str = "gpt-4-turbo"):
        """Initialize the agent with OpenAI LLM and tools.

        Args:
            model: OpenAI chat model name passed to ``ChatOpenAI``.

        Raises:
            ValueError: if neither ``OPENAI_KEY`` nor ``OPENAI_API_KEY`` is set.
        """
        print("Initializing LangGraphAgent...")

        # OPENAI_KEY takes precedence; OPENAI_API_KEY is the OpenAI SDK's usual name.
        self.api_key = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OPENAI_KEY or OPENAI_API_KEY environment variable is required")

        self.model = model
        self.graph = self._build_graph()
        print("LangGraphAgent initialized successfully.")

    def _build_graph(self):
        """Build and compile the assistant <-> tools LangGraph workflow."""
        llm = ChatOpenAI(model=self.model, temperature=0, api_key=self.api_key)
        llm_with_tools = llm.bind_tools(tools)
        sys_msg = SystemMessage(content=system_prompt)

        def assistant(state: MessagesState):
            """Call the tool-bound LLM on the conversation and append its reply."""
            messages = state["messages"]
            # The system prompt is prepended for the model call only; it is
            # not written back into the graph state.
            if not any(isinstance(msg, SystemMessage) for msg in messages):
                messages = [sys_msg] + messages
            response = llm_with_tools.invoke(messages)
            return {"messages": [response]}

        builder = StateGraph(MessagesState)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        # tools_condition routes to "tools" when the reply requests tool
        # calls, otherwise the run ends.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content (a string or a list of content blocks) to text."""
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type", "text") == "text":
                parts.append(str(block.get("text", "")))
        return "".join(parts)

    @classmethod
    def _extract_final_answer(cls, text: str) -> str:
        """Return the text after the last "FINAL ANSWER:" marker.

        Falls back to ``text`` unchanged when no marker is present or nothing
        follows the last marker.
        """
        matches = list(cls._FINAL_ANSWER_MARKER.finditer(text))
        if not matches:
            return text
        answer = text[matches[-1].end():].strip()
        return answer or text

    def __call__(self, question: str) -> str:
        """
        Process a question and return an answer.

        Args:
            question: The question to answer

        Returns:
            str: The answer. If the model's last reply contains a
            "FINAL ANSWER:" marker, only the text after the last marker is
            returned; otherwise the whole reply. Errors are returned as
            "Error: ..." strings rather than raised.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        try:
            result = self.graph.invoke({"messages": [HumanMessage(content=question)]})

            ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
            if not ai_messages:
                return "I couldn't generate a response. Please try again."

            reply = self._content_to_text(ai_messages[-1].content)
            answer = self._extract_final_answer(reply)
            print(f"Agent returning answer (first 100 chars): {answer[:100]}...")
            return answer

        except Exception as e:
            # Top-level boundary: report the failure as the answer instead of
            # raising into the caller (e.g. an evaluation loop).
            print(f"Error processing question: {e}")
            return f"Error: {e}"
|
|
| |
# Alias so code that refers to ``BasicAgent`` gets the LangGraph-based agent.
BasicAgent = LangGraphAgent
|
|
if __name__ == "__main__":
    # Command-line smoke test: run the agent on a few sample questions.
    print("Testing LangGraphAgent...")
    # Accept the same variables LangGraphAgent.__init__ reads, so this check
    # does not reject a setup the agent itself would accept.
    if not (os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")):
        print("Error: OPENAI_KEY (or OPENAI_API_KEY) environment variable not set")
        print("Please set it with: export OPENAI_KEY=your-openai-api-key")
        sys.exit(1)

    try:
        agent = LangGraphAgent()
        test_questions = [
            "What is 15 * 23?",
            "Search Wikipedia for information about quantum computing",
            "What are the latest developments in AI according to recent papers on Arxiv?",
        ]

        for question in test_questions:
            print(f"\nQuestion: {question}")
            answer = agent(question)
            print(f"Answer: {answer}")

    except Exception as e:
        print(f"Error during testing: {e}")
        # Signal failure to the shell instead of exiting with status 0.
        sys.exit(1)