Nitien committed on
Commit c9cad9e · verified · 1 Parent(s): 81917a3

Upload 3 files

Files changed (3)
  1. config.py +52 -0
  2. customtools.py +261 -0
  3. gaia_agent.py +386 -0
config.py ADDED
@@ -0,0 +1,52 @@
"""
Configuration and constants for the GAIA agent.
Centralized configuration for easy management and customization.
"""

import os
from dotenv import load_dotenv

load_dotenv()

# ==================== API KEYS ====================
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "")

# ==================== LLM CONFIGURATION ====================
LLM_MODEL = "inclusionai/Ling-2.6-1T:free"
LLM_TEMPERATURE = 0
LLM_MAX_ITERATIONS = 5

# ==================== TOOL CONFIGURATION ====================
WIKIPEDIA_MAX_PAGES = 2
WIKIPEDIA_CHAR_LIMIT = 8_000

YOUTUBE_CHAR_LIMIT = 10_000

WEB_SEARCH_RESULTS_LIMIT = 3

EXCEL_PREVIEW_ROWS = 50

# ==================== OUTPUT CONFIGURATION ====================
OUTPUT_FILE = "/home/nitin/AI/hfagent/results.jsonl"
FINAL_ANSWER_MAX_LENGTH = 100
REASONING_TRACE_MAX_LENGTH = 200

# ==================== TOOL NAMES ====================
TOOL_NAMES = {
    "WEB_SEARCH": "web_search",
    "WIKI_SEARCH": "wikisearch",
    "YOUTUBE_TRANSCRIPT": "youtube_transcript",
    "EXCEL_ANALYSIS": "load_and_analyze_excel_file",
    "IMAGE_TEXT": "extract_text_from_image",
    "AUDIO_TRANSCRIBE": "transcribe_audio",
    "ADD": "addition_tool",
    "SUBTRACT": "subtraction_tool",
    "MULTIPLY": "multiplication_tool",
    "NONE": "none",
}

# ==================== VALIDATION ====================
VALID_EXCEL_EXTENSIONS = (".xlsx", ".xls", ".csv")
VALID_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif")
VALID_AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".flac", ".ogg")
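The two API keys above are read from the environment when config.py is imported. A minimal smoke test, assuming a local .env file next to the code containing OPENROUTER_API_KEY and TAVILY_API_KEY; the check itself is illustrative and not part of the upload:

# Hypothetical check that the .env values were picked up before running the agent.
import config

for name in ("OPENROUTER_API_KEY", "TAVILY_API_KEY"):
    value = getattr(config, name)
    print(f"{name}: {'set' if value else 'MISSING'}")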
customtools.py ADDED
@@ -0,0 +1,261 @@
"""
Custom tools for the GAIA agent.
Includes tools for web search, file analysis, text extraction, and more.
"""

import os
import re
import subprocess
from tempfile import NamedTemporaryFile
from pathlib import Path

import cv2
import pandas as pd
import pytesseract
import whisper
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader
from langchain_openrouter import ChatOpenRouter
from tavily import TavilyClient
from youtube_transcript_api import YouTubeTranscriptApi

from config import (
    OPENROUTER_API_KEY,
    TAVILY_API_KEY,
    LLM_MODEL,
    LLM_TEMPERATURE,
    WIKIPEDIA_MAX_PAGES,
    WIKIPEDIA_CHAR_LIMIT,
    YOUTUBE_CHAR_LIMIT,
    WEB_SEARCH_RESULTS_LIMIT,
    EXCEL_PREVIEW_ROWS,
)
from prompts import (
    EXCEL_ANALYSIS_PROMPT_TEMPLATE,
    WEB_SEARCH_EXTRACTION_PROMPT_TEMPLATE,
)

load_dotenv()


@tool
def wikisearch(query: str, max_pages: int = None) -> str:
    """Search Wikipedia pages and return concatenated page texts."""
    max_pages = max_pages or WIKIPEDIA_MAX_PAGES
    print(f"wikisearch called with query: {query}, max_pages: {max_pages}")

    try:
        docs = WikipediaLoader(query=query, load_max_docs=max_pages).load()
        joined = "\n\n---\n\n".join(d.page_content for d in docs)
        return joined[:WIKIPEDIA_CHAR_LIMIT]
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"


@tool
def youtube_transcript(url: str, chars: int = None) -> str:
    """Fetch a YouTube video transcript."""
    chars = chars or YOUTUBE_CHAR_LIMIT
    video_id_match = re.search(r"[?&]v=([A-Za-z0-9_\-]{11})", url)

    if not video_id_match:
        return "Error: Could not extract video ID from URL"

    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id_match.group(1))
        text = " ".join(piece["text"] for piece in transcript)
        return text[:chars]
    except Exception as exc:
        print(f"Error fetching YouTube transcript: {exc}")
        return f"Error fetching transcript: {str(exc)}"


@tool
def web_search(query: str) -> str:
    """Perform a web search and extract concise factual answers."""
    print(f"web_search called with query: {query}")

    if not TAVILY_API_KEY:
        return "Error: TAVILY_API_KEY not set in environment"

    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
        search_results = tavily_client.search(query)
        print("Search results obtained")

        # Format results as a readable string
        if search_results and isinstance(search_results, dict) and "results" in search_results:
            formatted = "\n".join([
                f"- {r.get('title', '')}: {r.get('content', '')[:200]}"
                for r in search_results["results"][:WEB_SEARCH_RESULTS_LIMIT]
            ])
            return formatted if formatted else "No results found"

        return str(search_results)
    except Exception as e:
        print(f"Error during web search: {e}")
        return f"Error during web search: {str(e)}"


@tool
def addition_tool(a: str, b: str) -> str:
    """Add two numbers represented as strings."""
    try:
        num_a = float(a)
        num_b = float(b)
        result = num_a + num_b
        return str(result)
    except ValueError:
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during addition: {str(e)}"


@tool
def subtraction_tool(a: str, b: str) -> str:
    """Subtract two numbers represented as strings."""
    try:
        num_a = float(a)
        num_b = float(b)
        result = num_a - num_b
        return str(result)
    except ValueError:
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during subtraction: {str(e)}"


@tool
def multiplication_tool(a: str, b: str) -> str:
    """Multiply two numbers represented as strings."""
    try:
        num_a = float(a)
        num_b = float(b)
        result = num_a * num_b
        return str(result)
    except ValueError:
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during multiplication: {str(e)}"


@tool
def division_tool(a: str, b: str) -> str:
    """Divide two numbers represented as strings."""
    try:
        num_a = float(a)
        num_b = float(b)
        if num_b == 0:
            return "Error: Division by zero is not allowed."
        result = num_a / num_b
        return str(result)
    except ValueError:
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during division: {str(e)}"


@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from image files using OCR.
    Works with .jpg, .png, .bmp, .tiff formats only.

    Args:
        image_path: Full path to the image file
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None for unreadable or missing files
            return f"Error: could not read image at {image_path}"
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        thresh = cv2.bitwise_not(thresh)

        custom_config = r'--oem 3 --psm 6'
        full_text = pytesseract.image_to_string(thresh, config=custom_config)

        return f"Extracted text from image:\n\n{full_text}"
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"


@tool
def run_python(code: str) -> str:
    """Execute Python code in a subprocess and return the last line of output."""
    try:
        with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
            f.write(code)
            path = f.name

        proc = subprocess.run(
            ["python", path], capture_output=True, text=True, timeout=45
        )

        out = proc.stdout.strip().splitlines()
        return out[-1] if out else ""
    except Exception as exc:
        print(f"Error executing Python code: {exc}")
        return f"py_error:{exc}"


@tool
def load_and_analyze_excel_file(query: str, file_path: str) -> str:
    """
    Load and analyze data from Excel/CSV files (.xlsx, .xls, .csv).

    Args:
        query: Data analysis question (e.g., "Count records where status=active")
        file_path: Full path to the Excel/CSV file
    """
    print(f"load_and_analyze_excel_file called - Query: {query}, File: {file_path}")

    try:
        # Read the file based on extension
        if file_path.lower().endswith(".csv"):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Create basic data summary
        result = "File loaded successfully.\n"
        result += f"Rows: {len(df)}, Columns: {len(df.columns)}\n"
        result += f"Column names: {', '.join(df.columns.tolist())}\n\n"

        # Prepare data context for the LLM
        data_summary = (
            f"DataFrame:\n{df.to_string(max_rows=EXCEL_PREVIEW_ROWS)}\n\n"
            f"Data Types:\n{df.dtypes.to_string()}"
        )

        # Create analysis prompt
        analysis_prompt = EXCEL_ANALYSIS_PROMPT_TEMPLATE.format(
            data_summary=data_summary,
            query=query,
        )

        # Get LLM analysis
        tool_llm = ChatOpenRouter(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            api_key=OPENROUTER_API_KEY,
        )

        message = HumanMessage(content=analysis_prompt)
        llm_response = tool_llm.invoke([message])

        result += f"Analysis:\n{llm_response.content}"
        print("Excel analysis completed")
        return result

    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"


@tool
def transcribe_audio(audio_file: str) -> str:
    """Transcribe audio files and return the transcript."""
    try:
        model = whisper.load_model("base")
        output = model.transcribe(audio=str(Path(audio_file)), language='en')
        print("Audio transcription completed")
        return output['text']
    except Exception as exc:
        print(f"Error transcribing audio: {exc}")
        return f"transcription_error:{exc}"
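Each function above is a LangChain @tool, so it can be exercised with .invoke() directly, outside the agent graph. A quick sketch, assuming customtools.py is importable; the inputs are illustrative:

# Hypothetical direct calls to two of the tools as a sanity check.
from customtools import addition_tool, wikisearch

print(addition_tool.invoke({"a": "50", "b": "75"}))       # expected "125.0"
print(wikisearch.invoke({"query": "Alan Turing"})[:200])  # first 200 characters of page text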
gaia_agent.py ADDED
@@ -0,0 +1,386 @@
"""
GAIA Agent - Multi-step reasoning agent for complex tasks.
Uses a LangGraph StateGraph for workflow orchestration and multiple specialized tools.
"""

import os
import json
from typing import List, Dict, Any, Optional, Literal, TypedDict
from pathlib import Path

from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage
from langchain_openrouter import ChatOpenRouter
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

from customtools import (
    load_and_analyze_excel_file,
    extract_text_from_image,
    web_search,
    wikisearch,
    youtube_transcript,
    addition_tool,
    subtraction_tool,
    multiplication_tool,
    transcribe_audio,
)
from config import (
    OPENROUTER_API_KEY,
    LLM_MODEL,
    LLM_TEMPERATURE,
    OUTPUT_FILE,
    FINAL_ANSWER_MAX_LENGTH,
    REASONING_TRACE_MAX_LENGTH,
)
from prompts import (
    PLANNER_PROMPT_TEMPLATE,
    FINALIZER_PROMPT_TEMPLATE,
)

load_dotenv()

memory = MemorySaver()

# Module-level LLM handle, initialized in process_questions()
llm = None


def connect_models():
    """Initialize and return the LLM instance."""
    try:
        print(f"Connecting to LLM: {LLM_MODEL}")
        llm = ChatOpenRouter(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            api_key=OPENROUTER_API_KEY,
        )
        return llm
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        raise


# Tool registry
TOOLS = {
    "web_search": web_search,
    "addition_tool": addition_tool,
    "subtraction_tool": subtraction_tool,
    "multiplication_tool": multiplication_tool,
    "youtube_transcript": youtube_transcript,
    "load_and_analyze_excel_file": load_and_analyze_excel_file,
    "extract_text_from_image": extract_text_from_image,
    "wikisearch": wikisearch,
    "transcribe_audio": transcribe_audio,
}


class AgentState(TypedDict):
    """State structure for the agent workflow."""
    question: str
    plan: List[Dict[str, Any]]
    current_step: int
    selected_tool: Optional[str]
    tool_input: Optional[str]
    tool_output: Optional[str]
    intermediate_results: List[Dict[str, Any]]
    final_answer: Optional[str]
    done: bool


class Step(BaseModel):
    """Represents a single step in the plan."""
    step_number: int
    description: str
    tool: Literal[
        "web_search",
        "wikisearch",
        "youtube_transcript",
        "load_and_analyze_excel_file",
        "extract_text_from_image",
        "transcribe_audio",
        "addition_tool",
        "subtraction_tool",
        "multiplication_tool",
        "none",
    ]
    tool_input: str


class Plan(BaseModel):
    """Structured plan with multiple steps."""
    steps: List[Step]


def planner_node(state: AgentState):
    """Planner node: breaks the question down into steps."""
    prompt = PLANNER_PROMPT_TEMPLATE.format(question=state['question'])

    planner_llm = llm.with_structured_output(Plan, method="json_schema")
    response = planner_llm.invoke(prompt)

    print(f"Planner generated {len(response.steps)} steps")

    return {
        **state,
        "plan": [step.dict() for step in response.steps],
        "current_step": 0,
        "intermediate_results": [],
        "done": False,
    }


def execute_step_node(state: AgentState):
    """Execute step node: prepares the tool invocation."""
    step = state["plan"][state["current_step"]]
    tool_name = step.get("tool", "none")

    print(f"Executing step {state['current_step'] + 1}/{len(state['plan'])}: {tool_name}")

    return {
        **state,
        "tool_input": step.get("tool_input"),
        "selected_tool": tool_name,
    }


def tool_node(state: AgentState):
    """Tool execution node: invokes the selected tool."""
    tool_name = state.get("selected_tool")
    tool_input = state.get("tool_input")

    if tool_name == "none":
        return {**state, "tool_output": tool_input}

    print(f"Invoking tool: {tool_name}")
    tool = TOOLS.get(tool_name)

    # Special handling for load_and_analyze_excel_file: parse "query|file_path" format
    if tool_name == "load_and_analyze_excel_file" and isinstance(tool_input, str) and "|" in tool_input:
        parts = tool_input.split("|", 1)
        query = parts[0].strip()
        file_path = parts[1].strip()
        tool_input = {"query": query, "file_path": file_path}
        print(f"Parsed Excel input - Query: '{query[:50]}...', File: '{file_path}'")

    # Special handling for math tools: parse "a,b" format
    if tool in (addition_tool, subtraction_tool, multiplication_tool):
        try:
            a, b = tool_input.split(",")
            tool_input = {"a": a.strip(), "b": b.strip()}
        except Exception as e:
            print(f"Error parsing math tool input: {e}")
            return {**state, "tool_output": f"Error parsing input: {e}"}

    if not tool:
        return {**state, "tool_output": f"Unknown tool: {tool_name}"}

    try:
        result = tool.invoke(tool_input)
    except Exception as e:
        print(f"Error invoking tool {tool_name}: {e}")
        result = f"Tool error: {str(e)}"

    return {**state, "tool_output": result}


def update_state_node(state: AgentState):
    """Update state node: records tool output and advances to the next step."""
    step = state["plan"][state["current_step"]]

    state["intermediate_results"].append({
        "step": step,
        "output": state["tool_output"]
    })

    next_step = state["current_step"] + 1
    done = next_step >= len(state["plan"])

    return {
        **state,
        "current_step": next_step,
        "done": done,
    }


def should_continue(state: AgentState):
    """Conditional edge: determines whether the workflow should continue or finalize."""
    return "finalize" if state["done"] else "continue"


def finalizer_node(state: AgentState):
    """Finalizer node: summarizes results and generates the final answer."""
    # Format intermediate results for the finalizer
    results_text = "\n".join([
        f"Step {i+1}: {r['step'].get('description', '')}\n Output: {str(r['output'])[:100]}..."
        for i, r in enumerate(state["intermediate_results"])
    ])

    prompt = FINALIZER_PROMPT_TEMPLATE.format(
        question=state['question'],
        intermediate_results=results_text
    )

    response = llm.invoke(prompt)

    return {
        **state,
        "final_answer": response.content,
    }


def create_agent_workflow():
    """Build and compile the LangGraph workflow."""
    graph = StateGraph(AgentState)

    # Nodes
    graph.add_node("planner", planner_node)
    graph.add_node("executor", execute_step_node)
    graph.add_node("tool", tool_node)
    graph.add_node("updater", update_state_node)
    graph.add_node("finalizer", finalizer_node)

    # Entry
    graph.set_entry_point("planner")

    # Flow
    graph.add_edge("planner", "executor")
    graph.add_edge("executor", "tool")
    graph.add_edge("tool", "updater")

    # Loop
    graph.add_conditional_edges(
        "updater",
        should_continue,
        {
            "continue": "executor",
            "finalize": "finalizer"
        }
    )

    # End
    graph.add_edge("finalizer", END)

    return graph.compile()


def format_reasoning_trace(intermediate_results: List[Dict[str, Any]]) -> str:
    """Format intermediate results into a readable reasoning trace."""
    trace_lines = []
    for result in intermediate_results:
        step = result.get("step", {})
        output = str(result.get("output", ""))
        description = step.get("description", "Unknown step")
        tool = step.get("tool", "none")

        trace_lines.append(f"Step: {description}")
        trace_lines.append(f" Tool: {tool}")
        trace_lines.append(f" Output: {output[:200]}{'...' if len(output) > 200 else ''}")

    return "\n".join(trace_lines)


def process_questions(questions_file: str = None, questions_list: List[str] = None) -> str:
    """
    Process multiple questions and save results to a file.

    Args:
        questions_file: Path to a file containing questions (one per line)
        questions_list: List of questions to process

    Returns:
        Path to the output file with results
    """
    global llm
    llm = connect_models()
    print(f"LLM available: {llm}")
    agent = create_agent_workflow()

    # Get questions from either a file or a list
    if questions_file:
        with open(questions_file, 'r') as f:
            questions = [q.strip() for q in f.readlines() if q.strip()]
    elif questions_list:
        questions = questions_list
    else:
        raise ValueError("Either questions_file or questions_list must be provided")

    results = []

    for idx, question in enumerate(questions, 1):
        task_id = f"task_id_{idx}"
        print(f"\n{'='*80}")
        print(f"Processing {task_id}: {question[:80]}...")
        print(f"{'='*80}")

        try:
            # Run the agent
            result = agent.invoke({
                "question": question
            })

            # Extract the final answer and reasoning trace
            final_answer = result.get("final_answer", "No answer generated")
            intermediate_results = result.get("intermediate_results", [])

            # Format the reasoning trace
            reasoning_trace = format_reasoning_trace(intermediate_results)

            # Create the result object
            task_result = {
                "task_id": task_id,
                "model_answer": final_answer,
                "reasoning_trace": reasoning_trace
            }

            results.append(task_result)

            print(f"Completed {task_id}")
            print(f"Answer: {final_answer[:100]}...")

        except Exception as e:
            print(f"✗ Error processing {task_id}: {str(e)}")
            task_result = {
                "task_id": task_id,
                "model_answer": f"Error: {str(e)}",
                "reasoning_trace": "Failed to execute agent"
            }
            results.append(task_result)

    # Save results to file
    output_file = OUTPUT_FILE
    with open(output_file, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')

    print(f"\n{'='*80}")
    print(f"All tasks completed. Results saved to: {output_file}")
    print(f"{'='*80}")

    return output_file


if __name__ == "__main__":
    # Example questions to process
    questions = [
        #"What is the square of the population of France in millions?",
        #"What is 50 plus 75?"
    ]

    # Process all questions
    output_file = process_questions(questions_list=questions)

    # Print the results
    print("\nResults from file:")
    with open(output_file, 'r') as f:
        for line in f:
            result = json.loads(line)
            print(f"\nTask ID: {result['task_id']}")
            print(f"Answer: {result['model_answer']}")
            print(f"Reasoning:\n{result['reasoning_trace']}")
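A minimal end-to-end driver, assuming prompts.py and the API keys are in place; the questions are illustrative, while process_questions and the JSONL output format come from the files above:

# Hypothetical run: plan, execute tools, finalize, then read back the JSONL results.
import json
from gaia_agent import process_questions

out_path = process_questions(questions_list=[
    "What is 50 plus 75?",
    "Who founded Wikipedia?",
])

with open(out_path) as f:
    for line in f:
        record = json.loads(line)
        print(record["task_id"], "->", record["model_answer"])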