peachchange committed on
Commit
42a7415
·
verified ·
1 Parent(s): 72a4ffb

Upload 5 files

Browse files
Files changed (5) hide show
  1. __init__.py +0 -0
  2. app_langgraph.py +101 -0
  3. math_tools.py +52 -0
  4. multimodal_tools.py +177 -0
  5. search_tools.py +53 -0
__init__.py ADDED
File without changes
app_langgraph.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
+ from langchain_core.globals import set_debug
10
+ from langchain_groq import ChatGroq
11
+ from tools.search_tools import web_search, arvix_search, wiki_search
12
+ from tools.math_tools import multiply, add, subtract, divide
13
+ # from supabase.client import Client, create_client
14
+ # from langchain.tools.retriever import create_retriever_tool
15
+ # from langchain_community.vectorstores import SupabaseVectorStore
16
+ import json
17
+ from tools.multimodal_tools import extract_text, analyze_image_tool, analyze_audio_tool
18
+ from langchain_google_genai import ChatGoogleGenerativeAI
19
+
20
# set_debug(True)
# Load API keys (HF_TOKEN, GEMINI_API_KEY, ...) from a local .env file
# at import time, before any model client is constructed.
load_dotenv()

# Every tool exposed to the agent: basic arithmetic, web/wiki/arxiv
# search, and multimodal helpers (OCR, image analysis, audio analysis).
# NOTE(review): `modulus` exists in math_tools but is not registered here.
tools = [
    multiply,
    add,
    subtract,
    divide,
    web_search,
    wiki_search,
    arvix_search,
    extract_text,
    analyze_image_tool,
    analyze_audio_tool
]
35
+
36
def build_graph():
    """Build and compile the LangGraph agent graph.

    Wires a Gemini chat model (with `tools` bound) into a two-node
    StateGraph: an ``assistant`` node that calls the LLM, and a ``tools``
    node that executes requested tool calls; control loops between them
    via ``tools_condition`` until the model emits a final answer.

    Returns:
        The compiled LangGraph graph, ready for ``.invoke()``.
    """
    hf_token = os.getenv("HF_TOKEN")  # only used by the commented-out HF backend below
    api_key = os.getenv("GEMINI_API_KEY")
    # llm = HuggingFaceEndpoint(
    #     repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     huggingfacehub_api_token=hf_token,
    # )
    # chat = ChatHuggingFace(llm=llm, verbose=True)
    # llm_with_tools = chat.bind_tools(tools)

    # llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    # llm_with_tools = llm.bind_tools(tools)

    chat = ChatGoogleGenerativeAI(
        model="gemini-2.5-pro-preview-05-06",
        temperature=0,
        max_retries=2,
        google_api_key=api_key,
        thinking_budget=0,
    )
    chat_with_tools = chat.bind_tools(tools)

    def assistant(state: MessagesState):
        # LLM node: prepend the system prompt and call the tool-bound model.
        # FIX: the prompt was a bare str, which LangChain coerces into a
        # *HumanMessage* when placed in the message list — wrap it in
        # SystemMessage so it is actually sent as the system turn.
        # FIX: the implicitly concatenated literals were missing joining
        # spaces ("constraints.Pay attention"), garbling the prompt.
        sys_msg = SystemMessage(content=(
            "You are a helpful assistant with access to tools. Understand user requests accurately. Use your tools when needed to answer effectively. Strictly follow all user instructions and constraints. "
            "Pay attention: your output needs to contain only the final answer without any reasoning since it will be strictly evaluated against a dataset which contains only the specific response. "
            "Your final output needs to be just the string or integer containing the answer, not an array or technical stuff."
        ))
        return {
            "messages": [chat_with_tools.invoke([sys_msg] + state["messages"])],
        }

    ## The graph
    builder = StateGraph(MessagesState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the latest message requires a tool, route to tools
        # Otherwise, provide a direct response
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()
82
+
83
# test
if __name__ == "__main__":
    graph = build_graph()

    # Read the evaluation set: one JSON object per line.
    with open('sample.jsonl', 'r') as jsonl_file:
        json_list = list(jsonl_file)

    start = 10  # revisit 5, 8,
    end = start + 1
    for json_str in json_list[start:end]:
        json_data = json.loads(json_str)
        print(f"Question::::::::: {json_data['Question']}")
        print(f"Final answer::::: {json_data['Final answer']}")

        # Run one question through the agent and dump the full transcript.
        result = graph.invoke(
            {"messages": [HumanMessage(content=json_data['Question'])]}
        )
        for m in result["messages"]:
            m.pretty_print()
math_tools.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+
3
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
11
+
12
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
21
+
22
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers (a minus b).

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
31
+
32
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int (dividend)
        b: second int (divisor)

    Returns:
        The true quotient ``a / b``. FIX: annotated as ``float`` — the
        original said ``int``, but ``/`` is true division and always
        yields a float (e.g. ``divide(3, 2) == 1.5``).

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
43
+
44
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
multimodal_tools.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
4
+ from langchain_google_genai import ChatGoogleGenerativeAI
5
+ from langchain.tools import Tool
6
+ from langchain_core.tools import tool
7
+
8
# Gemini API key is read at import time, so the environment must already
# be populated (e.g. by load_dotenv() in the entry module) before this
# module is imported.
api_key = os.getenv("GEMINI_API_KEY")

# Shared multimodal Gemini chat model used by every tool in this module
# (text extraction, image analysis, audio analysis).
vision_llm = ChatGoogleGenerativeAI(
    model= "gemini-2.5-flash-preview-05-20",
    temperature=0,
    max_retries=2,
    google_api_key=api_key
)
17
+
18
+ @tool("extract_text_tool", parse_docstring=True)
19
+ def extract_text(img_path: str) -> str:
20
+ """Extract text from an image file using a multimodal model.
21
+
22
+ Args:
23
+ img_path (str): The path to the image file from which to extract text.
24
+
25
+ Returns:
26
+ str: The extracted text from the image, or an empty string if an error occurs.
27
+ """
28
+ all_text = ""
29
+ try:
30
+ # Read image and encode as base64
31
+ with open(img_path, "rb") as image_file:
32
+ image_bytes = image_file.read()
33
+
34
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
35
+
36
+ # Prepare the prompt including the base64 image data
37
+ message = [
38
+ HumanMessage(
39
+ content=[
40
+ {
41
+ "type": "text",
42
+ "text": (
43
+ "Extract all the text from this image. "
44
+ "Return only the extracted text, no explanations."
45
+ ),
46
+ },
47
+ {
48
+ "type": "image_url",
49
+ "image_url": {
50
+ "url": f"data:image/png;base64,{image_base64}"
51
+ },
52
+ },
53
+ ]
54
+ )
55
+ ]
56
+
57
+ # Call the vision-capable model
58
+ response = vision_llm.invoke(message)
59
+
60
+ # Append extracted text
61
+ all_text += response.content + "\n\n"
62
+
63
+ return all_text.strip()
64
+ except Exception as e:
65
+ # A butler should handle errors gracefully
66
+ error_msg = f"Error extracting text: {str(e)}"
67
+ print(error_msg)
68
+ return ""
69
+
70
+ @tool("analyze_image_tool", parse_docstring=True)
71
+ def analyze_image_tool(user_query: str, img_path: str) -> str:
72
+ """Answer the question reasoning on the image.
73
+
74
+ Args:
75
+ user_query (str): The question to be answered based on the image.
76
+ img_path (str): Path to the image file to be analyzed.
77
+
78
+ Returns:
79
+ str: The answer to the query based on image content, or an empty string if an error occurs.
80
+ """
81
+ all_text = ""
82
+ try:
83
+ # Read image and encode as base64
84
+ with open(img_path, "rb") as image_file:
85
+ image_bytes = image_file.read()
86
+
87
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
88
+
89
+ # Prepare the prompt including the base64 image data
90
+ message = [
91
+ HumanMessage(
92
+ content=[
93
+ {
94
+ "type": "text",
95
+ "text": (
96
+ f"User query: {user_query}"
97
+ ),
98
+ },
99
+ {
100
+ "type": "image_url",
101
+ "image_url": {
102
+ "url": f"data:image/png;base64,{image_base64}"
103
+ },
104
+ },
105
+ ]
106
+ )
107
+ ]
108
+
109
+ # Call the vision-capable model
110
+ response = vision_llm.invoke(message)
111
+
112
+ # Append extracted text
113
+ all_text += response.content + "\n\n"
114
+
115
+ return all_text.strip()
116
+ except Exception as e:
117
+ # A butler should handle errors gracefully
118
+ error_msg = f"Error analyzing image: {str(e)}"
119
+ print(error_msg)
120
+ return ""
121
+
122
+ @tool("analyze_audio_tool", parse_docstring=True)
123
+ def analyze_audio_tool(user_query: str, audio_path: str) -> str:
124
+ """Answer the question by reasoning on the provided audio file.
125
+
126
+ Args:
127
+ user_query (str): The question to be answered based on the audio content.
128
+ audio_path (str): Path to the audio file (e.g., .mp3, .wav, .flac, .aac, .ogg).
129
+
130
+ Returns:
131
+ str: The answer to the query based on audio content, or an error message/empty string if an error occurs.
132
+ """
133
+ try:
134
+ # Determine MIME type from file extension
135
+ _filename, file_extension = os.path.splitext(audio_path)
136
+ file_extension = file_extension.lower()
137
+
138
+ supported_formats = {
139
+ ".mp3": "audio/mp3", ".wav": "audio/wav", ".flac": "audio/flac",
140
+ ".aac": "audio/aac", ".ogg": "audio/ogg"
141
+ }
142
+
143
+ if file_extension not in supported_formats:
144
+ return (f"Error: Unsupported audio file format '{file_extension}'. "
145
+ f"Supported extensions: {', '.join(supported_formats.keys())}.")
146
+ mime_type = supported_formats[file_extension]
147
+
148
+ # Read audio file and encode as base64
149
+ with open(audio_path, "rb") as audio_file:
150
+ audio_bytes = audio_file.read()
151
+ audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
152
+
153
+ # Prepare the prompt including the base64 audio data
154
+ message = [
155
+ HumanMessage(
156
+ content=[
157
+ {
158
+ "type": "text",
159
+ "text": f"User query: {user_query}",
160
+ },
161
+ {
162
+ "type": "audio",
163
+ "source_type": "base64",
164
+ "mime_type": mime_type,
165
+ "data": audio_base64
166
+ },
167
+ ]
168
+ )
169
+ ]
170
+
171
+ # Call the vision-capable model
172
+ response = vision_llm.invoke(message)
173
+ return response.content.strip()
174
+ except Exception as e:
175
+ error_msg = f"Error analyzing audio: {str(e)}"
176
+ print(error_msg)
177
+ return ""
search_tools.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from langchain_community.document_loaders import WikipediaLoader
3
+ from langchain_community.document_loaders import ArxivLoader
4
+ # Search engine specifically for LLMs
5
+ # from langchain_community.tools.tavily_search import TavilySearchResults
6
+ from langchain_tavily import TavilySearch
7
+
8
+
9
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        dict: ``{"web_results": <formatted string>}``. FIX: the original
        annotation said ``-> str`` but the function returns a dict.
    """
    # print(f"Web search query:::::::::::: {query}")
    search_docs = TavilySearch(max_results=3).invoke({"query":query})
    # Wrap each hit in a pseudo-XML <Document> envelope so the LLM can
    # attribute content to its source URL/title.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc["url"]}" page="{doc["title"]}"/>\n{doc["content"]}\n</Document>'
            for doc in search_docs['results']
        ])
    # print(f"Web search result:::::::::::: {formatted_search_docs}")
    return {"web_results": formatted_search_docs}
24
+
25
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        dict: ``{"wiki_results": <formatted string>}``. FIX: the original
        annotation said ``-> str`` but the function returns a dict.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each page in a pseudo-XML <Document> envelope carrying its
    # source URL so the LLM can cite it.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ])

    return {"wiki_results": formatted_search_docs}
40
+
41
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        dict: ``{"arvix_results": <formatted string>}``. FIX: the original
        annotation said ``-> str`` but the function returns a dict.

    Note:
        The name keeps the existing "arvix" spelling (sic) because callers
        import it under that name.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each paper to its first 1000 characters to keep the tool
    # output within a reasonable context budget.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return {"arvix_results": formatted_search_docs}