Spaces:
Running
Running
Upload 3 files
Browse files- config.py +52 -0
- customtools.py +261 -0
- gaia_agent.py +386 -0
config.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Configuration and constants for the GAIA agent.
Centralized configuration for easy management and customization.
"""

import os

from dotenv import load_dotenv

load_dotenv()

# ==================== API KEYS ====================
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "")

# ==================== LLM CONFIGURATION ====================
LLM_MODEL = "inclusionai/Ling-2.6-1T:free"
LLM_TEMPERATURE = 0          # deterministic output for reproducible answers
LLM_MAX_ITERATIONS = 5

# ==================== TOOL CONFIGURATION ====================
WIKIPEDIA_MAX_PAGES = 2
WIKIPEDIA_CHAR_LIMIT = 8_000    # truncate page text to keep prompts small

YOUTUBE_CHAR_LIMIT = 10_000

WEB_SEARCH_RESULTS_LIMIT = 3

EXCEL_PREVIEW_ROWS = 50

# ==================== OUTPUT CONFIGURATION ====================
# The original hard-coded absolute path is machine-specific; allow overriding
# via the GAIA_OUTPUT_FILE environment variable while keeping the old default
# so existing setups are unaffected.
OUTPUT_FILE = os.getenv("GAIA_OUTPUT_FILE", "/home/nitin/AI/hfagent/results.jsonl")
FINAL_ANSWER_MAX_LENGTH = 100
REASONING_TRACE_MAX_LENGTH = 200

# ==================== TOOL NAMES ====================
# Canonical names for each tool; values must match the keys of the TOOLS
# registry in gaia_agent.py.
TOOL_NAMES = {
    "WEB_SEARCH": "web_search",
    "WIKI_SEARCH": "wikisearch",
    "YOUTUBE_TRANSCRIPT": "youtube_transcript",
    "EXCEL_ANALYSIS": "load_and_analyze_excel_file",
    "IMAGE_TEXT": "extract_text_from_image",
    "AUDIO_TRANSCRIBE": "transcribe_audio",
    "ADD": "addition_tool",
    "SUBTRACT": "subtraction_tool",
    "MULTIPLY": "multiplication_tool",
    "NONE": "none",
}

# ==================== VALIDATION ====================
VALID_EXCEL_EXTENSIONS = (".xlsx", ".xls", ".csv")
VALID_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif")
VALID_AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".flac", ".ogg")
|
customtools.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom tools for the GAIA agent.
|
| 3 |
+
Includes tools for web search, file analysis, text extraction, and more.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import subprocess
|
| 9 |
+
from tempfile import NamedTemporaryFile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import pandas as pds
|
| 14 |
+
import pytesseract
|
| 15 |
+
import whisper
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
from langchain_core.messages import HumanMessage
|
| 18 |
+
from langchain_core.tools import tool
|
| 19 |
+
from langchain_community.document_loaders import WikipediaLoader
|
| 20 |
+
from langchain_openrouter import ChatOpenRouter
|
| 21 |
+
from tavily import TavilyClient
|
| 22 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
| 23 |
+
|
| 24 |
+
from config import (
|
| 25 |
+
OPENROUTER_API_KEY,
|
| 26 |
+
TAVILY_API_KEY,
|
| 27 |
+
LLM_MODEL,
|
| 28 |
+
LLM_TEMPERATURE,
|
| 29 |
+
WIKIPEDIA_MAX_PAGES,
|
| 30 |
+
WIKIPEDIA_CHAR_LIMIT,
|
| 31 |
+
YOUTUBE_CHAR_LIMIT,
|
| 32 |
+
WEB_SEARCH_RESULTS_LIMIT,
|
| 33 |
+
EXCEL_PREVIEW_ROWS,
|
| 34 |
+
)
|
| 35 |
+
from prompts import (
|
| 36 |
+
EXCEL_ANALYSIS_PROMPT_TEMPLATE,
|
| 37 |
+
WEB_SEARCH_EXTRACTION_PROMPT_TEMPLATE,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
load_dotenv()
|
| 41 |
+
@tool
def wikisearch(query: str, max_pages: int = None) -> str:
    """Search Wikipedia pages and return concatenated page texts."""
    max_pages = max_pages or WIKIPEDIA_MAX_PAGES
    print(f"wikisearch called with query: {query}, max_pages: {max_pages}")

    try:
        loader = WikipediaLoader(query=query, load_max_docs=max_pages)
        pages = loader.load()
        combined = "\n\n---\n\n".join(page.page_content for page in pages)
        # Truncate so the downstream agent prompt stays a manageable size.
        return combined[:WIKIPEDIA_CHAR_LIMIT]
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@tool
def youtube_transcript(url: str, chars: int = None) -> str:
    """Fetch a YouTube video's transcript.

    Args:
        url: Video URL. Accepts watch?v=..., youtu.be/..., /embed/... and
            /shorts/... forms (the original only handled "?v=").
        chars: Maximum characters to return (defaults to YOUTUBE_CHAR_LIMIT).

    Returns:
        The transcript text truncated to `chars`, or an error message string.
    """
    chars = chars or YOUTUBE_CHAR_LIMIT
    # Video IDs are always 11 chars; match the common URL shapes.
    video_id_match = re.search(
        r"(?:[?&]v=|youtu\.be/|/embed/|/shorts/)([A-Za-z0-9_\-]{11})", url
    )

    if not video_id_match:
        return "Error: Could not extract video ID from URL"

    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id_match.group(1))
        text = " ".join(piece["text"] for piece in transcript)
        return text[:chars]
    except Exception as exc:
        print(f"Error fetching YouTube transcript: {exc}")
        return f"Error fetching transcript: {str(exc)}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@tool
def web_search(query: str) -> str:
    """Run a Tavily web search and return a short, readable result summary."""
    print(f"web_search called with query: {query}")

    if not TAVILY_API_KEY:
        return "Error: TAVILY_API_KEY not set in environment"

    try:
        client = TavilyClient(api_key=TAVILY_API_KEY)
        response = client.search(query)
        print(f"Search results obtained")

        # Render the top hits as "- title: snippet" lines.
        if response and isinstance(response, dict) and "results" in response:
            lines = [
                f"- {hit.get('title', '')}: {hit.get('content', '')[:200]}"
                for hit in response["results"][:WEB_SEARCH_RESULTS_LIMIT]
            ]
            formatted = "\n".join(lines)
            return formatted if formatted else "No results found"

        # Unexpected response shape — fall back to its string form.
        return str(response)
    except Exception as e:
        print(f"Error during web search: {e}")
        return f"Error during web search: {str(e)}"
|
| 98 |
+
|
| 99 |
+
@tool
def addition_tool(a: str, b: str) -> str:
    """Add two numbers represented as strings."""
    try:
        total = float(a) + float(b)
        return str(total)
    except ValueError:
        # Non-numeric input: report it instead of raising.
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during addition: {str(e)}"
|
| 111 |
+
|
| 112 |
+
@tool
def subtraction_tool(a: str, b: str) -> str:
    """Subtract two numbers represented as strings."""
    try:
        difference = float(a) - float(b)
        return str(difference)
    except ValueError:
        # Non-numeric input: report it instead of raising.
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during subtraction: {str(e)}"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@tool
def multiplication_tool(a: str, b: str) -> str:
    """Multiply two numbers represented as strings."""
    try:
        product = float(a) * float(b)
        return str(product)
    except ValueError:
        # Non-numeric input: report it instead of raising.
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during multiplication: {str(e)}"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@tool
def division_tool(a: str, b: str) -> str:
    """Divide two numbers represented as strings."""
    try:
        dividend, divisor = float(a), float(b)
        # Guard the only arithmetic failure float division can hit here.
        if divisor == 0:
            return "Error: Division by zero is not allowed."
        return str(dividend / divisor)
    except ValueError:
        return "Invalid input: both a and b must be numbers."
    except Exception as e:
        return f"Error during division: {str(e)}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from image files using OCR.
    Works with .jpg, .png, .bmp, .tiff formats only.

    Args:
        image_path: Full path to the image file

    Returns:
        The OCR'd text prefixed with a header, or an error message string.
    """
    try:
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None (no exception) for a missing or
            # unreadable file; catch it here instead of letting cvtColor
            # fail with a cryptic assertion error.
            return f"Error extracting text from image: could not read file '{image_path}'"

        # Binarize with Otsu thresholding, then invert back to dark-on-light,
        # which Tesseract handles best.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        thresh = cv2.bitwise_not(thresh)

        # --oem 3: default engine; --psm 6: assume a uniform block of text.
        custom_config = r'--oem 3 --psm 6'
        full_text = pytesseract.image_to_string(thresh, config=custom_config)

        return f"Extracted text from image:\n\n{full_text}"
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@tool
def run_python(code: str) -> str:
    """Execute Python code in a subprocess and return the last stdout line.

    The code is written to a temporary file, executed with a 45-second
    timeout using the same interpreter as this process, and the temp file
    is always removed afterwards.

    Args:
        code: Python source to execute.

    Returns:
        The final line of stdout, "" if there was no output, or
        "py_error:<exc>" on failure.
    """
    import sys

    path = None
    try:
        with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
            f.write(code)
            path = f.name

        # sys.executable guarantees the same interpreter as the host process;
        # a bare "python" may resolve to a different (or missing) binary.
        proc = subprocess.run(
            [sys.executable, path], capture_output=True, text=True, timeout=45
        )

        out = proc.stdout.strip().splitlines()
        return out[-1] if out else ""
    except Exception as exc:
        print(f"Error executing Python code: {exc}")
        return f"py_error:{exc}"
    finally:
        # delete=False means nothing removes the file for us — do it here so
        # repeated calls don't leak temp files.
        if path and os.path.exists(path):
            os.remove(path)
|
| 199 |
+
|
| 200 |
+
@tool
def load_and_analyze_excel_file(query: str, file_path: str) -> str:
    """
    Load and analyze data from Excel/CSV files (.xlsx, .xls, .csv).

    Reads the file with pandas, builds a textual preview (capped at
    EXCEL_PREVIEW_ROWS rows), and asks a fresh ChatOpenRouter instance to
    answer `query` against that preview.

    Args:
        query: Data analysis question (e.g., "Count records where status=active")
        file_path: Full path to the Excel/CSV file

    Returns:
        A summary of the loaded file followed by the LLM's analysis, or an
        error message string on failure.
    """
    print(f"load_and_analyze_excel_file called - Query: {query}, File: {file_path}")

    try:
        # Read the file based on extension
        if file_path.lower().endswith(".csv"):
            df = pds.read_csv(file_path)
        else:
            # Anything that is not .csv is assumed to be an Excel workbook.
            df = pds.read_excel(file_path)

        # Create basic data summary
        result = f"File loaded successfully.\n"
        result += f"Rows: {len(df)}, Columns: {len(df.columns)}\n"
        result += f"Column names: {', '.join(df.columns.tolist())}\n\n"

        # Prepare data context for LLM
        # NOTE(review): to_string(max_rows=...) can still be very large for
        # wide frames — confirm it stays within the model's context window.
        data_summary = f"DataFrame:\n{df.to_string(max_rows=EXCEL_PREVIEW_ROWS)}\n\nData Types:\n{df.dtypes.to_string()}"

        # Create analysis prompt
        analysis_prompt = EXCEL_ANALYSIS_PROMPT_TEMPLATE.format(
            data_summary=data_summary,
            query=query
        )

        # Get LLM analysis (a new client is constructed on every call)
        tool_llm = ChatOpenRouter(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            api_key=OPENROUTER_API_KEY,
        )

        message = HumanMessage(content=analysis_prompt)
        llm_response = tool_llm.invoke([message])

        result += f"Analysis:\n{llm_response.content}"
        print(f"Excel analysis completed")
        return result

    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@tool
def transcribe_audio(audio_file: str) -> str:
    """Transcribe an audio file with Whisper and return the transcript text."""
    try:
        # "base" trades accuracy for speed; the model is loaded on every call.
        asr_model = whisper.load_model("base")
        result = asr_model.transcribe(audio=str(Path(audio_file)), language='en')
        print(f"Audio transcription completed")
        return result['text']
    except Exception as exc:
        print(f"Error transcribing audio: {exc}")
        return f"transcription_error:{exc}"
|
gaia_agent.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GAIA Agent - Multi-step reasoning agent for complex tasks.
|
| 3 |
+
Uses LanggraphStateGraph for workflow orchestration and multiple specialized tools.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from typing import List, Dict, Any, Optional, Literal
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
from langchain_core.messages import HumanMessage
|
| 14 |
+
from langchain_openrouter import ChatOpenRouter
|
| 15 |
+
from langgraph.graph import StateGraph, START, END
|
| 16 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 17 |
+
from typing import TypedDict
|
| 18 |
+
|
| 19 |
+
from customtools import (
|
| 20 |
+
load_and_analyze_excel_file,
|
| 21 |
+
extract_text_from_image,
|
| 22 |
+
web_search,
|
| 23 |
+
wikisearch,
|
| 24 |
+
youtube_transcript,
|
| 25 |
+
addition_tool,
|
| 26 |
+
subtraction_tool,
|
| 27 |
+
multiplication_tool,
|
| 28 |
+
transcribe_audio,
|
| 29 |
+
)
|
| 30 |
+
from config import (
|
| 31 |
+
OPENROUTER_API_KEY,
|
| 32 |
+
LLM_MODEL,
|
| 33 |
+
LLM_TEMPERATURE,
|
| 34 |
+
OUTPUT_FILE,
|
| 35 |
+
FINAL_ANSWER_MAX_LENGTH,
|
| 36 |
+
REASONING_TRACE_MAX_LENGTH,
|
| 37 |
+
)
|
| 38 |
+
from prompts import (
|
| 39 |
+
PLANNER_PROMPT_TEMPLATE,
|
| 40 |
+
FINALIZER_PROMPT_TEMPLATE,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
load_dotenv()
|
| 44 |
+
|
| 45 |
+
memory = MemorySaver()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def connect_models():
    """Initialize and return the LLM instance."""
    try:
        print(f"Connecting to LLM: {LLM_MODEL}")
        # All model settings come from config so there is one place to tune.
        return ChatOpenRouter(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            api_key=OPENROUTER_API_KEY,
        )
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        raise
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Tool registry
# Maps the tool names the planner emits (see Step.tool and config.TOOL_NAMES)
# to the actual @tool callables imported from customtools. tool_node uses
# this dict to dispatch each plan step.
TOOLS = {
    "web_search": web_search,
    "addition_tool": addition_tool,
    "subtraction_tool": subtraction_tool,
    "multiplication_tool": multiplication_tool,
    "youtube_transcript": youtube_transcript,
    "load_and_analyze_excel_file": load_and_analyze_excel_file,
    "extract_text_from_image": extract_text_from_image,
    "wikisearch": wikisearch,
    "transcribe_audio": transcribe_audio,
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AgentState(TypedDict):
    """State structure for the agent workflow."""
    # The user's original question.
    question: str
    # Ordered steps produced by the planner (serialized Step dicts).
    plan: List[Dict[str, Any]]
    # Index into `plan` of the step currently being executed.
    current_step: int
    # Tool name chosen for the current step ("none" for direct answers).
    selected_tool: Optional[str]
    # Raw input string handed to the selected tool.
    tool_input: Optional[str]
    # Output of the most recent tool invocation.
    tool_output: Optional[str]
    # Accumulated {"step": ..., "output": ...} records, one per executed step.
    intermediate_results: List[Dict[str, Any]]
    # Answer produced by the finalizer node.
    final_answer: Optional[str]
    # True once every plan step has been executed.
    done: bool
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Step(BaseModel):
    """Represents a single step in the plan."""
    # 1-based position of the step within the plan.
    step_number: int
    # Human-readable description of what the step should accomplish.
    description: str
    # Which registered tool to run; must be a key of TOOLS, or "none" when
    # the step's answer is carried directly in tool_input.
    tool: Literal[
        "web_search",
        "wikisearch",
        "youtube_transcript",
        "load_and_analyze_excel_file",
        "extract_text_from_image",
        "transcribe_audio",
        "addition_tool",
        "subtraction_tool",
        "multiplication_tool",
        "none",
    ]
    # Raw input string passed to the tool (format is tool-specific, e.g.
    # "a,b" for the math tools or "query|file_path" for Excel analysis).
    tool_input: str
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Plan(BaseModel):
    """Structured plan with multiple steps."""
    # Ordered list of steps; the workflow executes them sequentially.
    steps: List[Step]
|
| 112 |
+
def planner_node(state: AgentState):
    """Planner node: breaks down question into steps."""
    prompt = PLANNER_PROMPT_TEMPLATE.format(question=state['question'])

    # Constrain the LLM to emit JSON matching the Plan schema.
    structured_planner = llm.with_structured_output(Plan, method="json_schema")
    plan = structured_planner.invoke(prompt)

    print(f"Planner generated {len(plan.steps)} steps")

    new_state = dict(state)
    new_state.update(
        plan=[s.dict() for s in plan.steps],
        current_step=0,
        intermediate_results=[],
        done=False,
    )
    return new_state
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def execute_step_node(state: AgentState):
    """Execute step node: prepares tool invocation."""
    current = state["plan"][state["current_step"]]
    chosen = current.get("tool", "none")

    print(f"Executing step {state['current_step'] + 1}/{len(state['plan'])}: {chosen}")

    updated = dict(state)
    updated["tool_input"] = current.get("tool_input")
    updated["selected_tool"] = chosen
    return updated
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def tool_node(state: AgentState):
    """Tool execution node: invokes the selected tool.

    Looks up the tool chosen by the executor in TOOLS, pre-processes the raw
    string tool_input for tools that need structured arguments, then stores
    the tool's result (or an error message) in state["tool_output"].
    """
    tool_name = state.get("selected_tool")
    tool_input = state.get("tool_input")

    # "none" means this plan step carries its answer directly in tool_input.
    if tool_name == "none":
        return {**state, "tool_output": tool_input}

    print(f"Invoking tool: {tool_name}")
    tool = TOOLS.get(tool_name)

    # Special handling for load_and_analyze_excel_file: parse query|file_path format
    if tool_name == "load_and_analyze_excel_file" and isinstance(tool_input, str) and "|" in tool_input:
        parts = tool_input.split("|", 1)
        query = parts[0].strip()
        file_path = parts[1].strip()
        tool_input = {"query": query, "file_path": file_path}
        print(f"Parsed Excel input - Query: '{query[:50]}...', File: '{file_path}'")

    # Special handling for math tools: parse "a,b" format
    # NOTE(review): assumes tool_input is a string with exactly one comma;
    # extra commas make the unpack raise and are reported as a parse error.
    if tool in (addition_tool, subtraction_tool, multiplication_tool):
        try:
            a, b = tool_input.split(",")
            tool_input = {"a": a.strip(), "b": b.strip()}
        except Exception as e:
            print(f"Error parsing math tool input: {e}")
            return {**state, "tool_output": f"Error parsing input: {e}"}

    # Unknown tool name: TOOLS.get() above returned None.
    if not tool:
        return {**state, "tool_output": f"Unknown tool: {tool_name}"}

    try:
        result = tool.invoke(tool_input)
    except Exception as e:
        # Tool failures are folded into the output so the workflow continues.
        print(f"Error invoking tool {tool_name}: {e}")
        result = f"Tool error: {str(e)}"

    return {**state, "tool_output": result}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def update_state_node(state: AgentState):
    """Update state node: records tool output and progresses to next step.

    Returns a new state dict. The intermediate_results list is rebuilt
    rather than mutated in place, so earlier state snapshots held by the
    graph/checkpointer are not aliased and silently changed.
    """
    step = state["plan"][state["current_step"]]

    # Append this step's record without mutating the incoming state.
    results = state["intermediate_results"] + [{
        "step": step,
        "output": state["tool_output"],
    }]

    next_step = state["current_step"] + 1
    done = next_step >= len(state["plan"])

    return {
        **state,
        "intermediate_results": results,
        "current_step": next_step,
        "done": done,
    }
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def should_continue(state: AgentState):
    """Conditional edge: determines if workflow should continue or finalize."""
    if state["done"]:
        return "finalize"
    return "continue"
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def finalizer_node(state: AgentState):
    """Finalizer node: summarizes results and generates final answer."""
    # Condense each executed step into a one-line summary for the prompt.
    summaries = []
    for i, r in enumerate(state["intermediate_results"]):
        summaries.append(
            f"Step {i+1}: {r['step'].get('description', '')}\n Output: {str(r['output'])[:100]}..."
        )
    results_text = "\n".join(summaries)

    prompt = FINALIZER_PROMPT_TEMPLATE.format(
        question=state['question'],
        intermediate_results=results_text
    )

    answer = llm.invoke(prompt)

    return {**state, "final_answer": answer.content}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def create_agent_workflow():
    """Build and compile the LangGraph state machine for the agent.

    Flow: planner -> (executor -> tool -> updater) loop -> finalizer -> END.
    The updater's conditional edge repeats the loop until should_continue
    reports every planned step has run.
    """
    graph = StateGraph(AgentState)

    # Nodes
    graph.add_node("planner", planner_node)
    graph.add_node("executor", execute_step_node)
    graph.add_node("tool", tool_node)
    graph.add_node("updater", update_state_node)
    graph.add_node("finalizer", finalizer_node)
    # Entry
    graph.set_entry_point("planner")

    # Flow
    graph.add_edge("planner", "executor")
    graph.add_edge("executor", "tool")
    graph.add_edge("tool", "updater")

    # Loop
    graph.add_conditional_edges(
        "updater",
        should_continue,
        {
            "continue": "executor",
            "finalize": "finalizer"
        }
    )

    # End
    graph.add_edge("finalizer", END)

    return graph.compile()
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def format_reasoning_trace(intermediate_results: List[Dict[str, Any]]) -> str:
    """Format intermediate results into a readable reasoning trace.

    Args:
        intermediate_results: List of {"step": {...}, "output": ...} records
            produced by the agent workflow.

    Returns:
        Multi-line text with one "Step / Tool / Output" group per record;
        each output is truncated to 200 characters.
    """
    trace_lines = []
    for result in intermediate_results:
        step = result.get("step", {})
        # Coerce to str first: tool outputs are not guaranteed to be strings,
        # and slicing a non-string (e.g. an int) would raise TypeError. The
        # original already used len(str(output)) for the ellipsis check but
        # sliced the raw value.
        output = str(result.get("output", ""))
        description = step.get("description", "Unknown step")
        tool = step.get("tool", "none")

        trace_lines.append(f"Step: {description}")
        trace_lines.append(f" Tool: {tool}")
        trace_lines.append(f" Output: {output[:200]}{'...' if len(output) > 200 else ''}")

    return "\n".join(trace_lines)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def process_questions(questions_file: str = None, questions_list: List[str] = None) -> str:
    """
    Process multiple questions and save results to a file

    Args:
        questions_file: Path to a file containing questions (one per line)
        questions_list: List of questions to process

    Returns:
        Path to the output file with results

    Raises:
        ValueError: If neither questions_file nor a non-empty questions_list
            is provided.
    """
    # The graph nodes (planner/finalizer) read the module-level `llm`, so it
    # must be initialized before the workflow runs.
    global llm
    llm = connect_models()
    print(f"LLM available: {llm}")
    agent = create_agent_workflow()

    # Get questions from either file or list
    if questions_file:
        with open(questions_file, 'r') as f:
            questions = [q.strip() for q in f.readlines() if q.strip()]
    elif questions_list:
        questions = questions_list
    else:
        raise ValueError("Either questions_file or questions_list must be provided")

    results = []

    for idx, question in enumerate(questions, 1):
        task_id = f"task_id_{idx}"
        print(f"\n{'='*80}")
        print(f"Processing {task_id}: {question[:80]}...")
        print(f"{'='*80}")

        try:
            # Run the agent; the planner node fills in the remaining state keys.
            result = agent.invoke({
                "question": question
            })

            # Extract the final answer and reasoning trace
            final_answer = result.get("final_answer", "No answer generated")
            intermediate_results = result.get("intermediate_results", [])

            # Format the reasoning trace
            reasoning_trace = format_reasoning_trace(intermediate_results)

            # Create the result object
            task_result = {
                "task_id": task_id,
                "model_answer": final_answer,
                "reasoning_trace": reasoning_trace
            }

            results.append(task_result)

            print(f"Completed {task_id}")
            print(f"Answer: {final_answer[:100]}...")

        except Exception as e:
            # Record the failure so one bad question doesn't abort the batch.
            print(f"✗ Error processing {task_id}: {str(e)}")
            task_result = {
                "task_id": task_id,
                "model_answer": f"Error: {str(e)}",
                "reasoning_trace": "Failed to execute agent"
            }
            results.append(task_result)

    # Save results as JSON Lines. Use the shared OUTPUT_FILE constant from
    # config instead of re-hardcoding the path here (the original duplicated
    # the literal and ignored the imported constant).
    output_file = OUTPUT_FILE
    with open(output_file, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')

    print(f"\n{'='*80}")
    print(f"All tasks completed. Results saved to: {output_file}")
    print(f"{'='*80}")

    return output_file
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
if __name__ == "__main__":
    # NOTE: the original had a `global llm` here; `global` is a no-op at
    # module scope, so it was removed. process_questions() declares it.

    # Example questions to process (uncomment or add your own). An empty
    # list makes process_questions raise ValueError, since an empty
    # questions_list is indistinguishable from "none provided".
    questions = [
        # "What is the square of the population of France in millions?",
        # "What is 50 plus 75?"
    ]

    # Process all questions
    output_file = process_questions(questions_list=questions)

    # Print the results
    print("\nResults from file:")
    with open(output_file, 'r') as f:
        for line in f:
            result = json.loads(line)
            print(f"\nTask ID: {result['task_id']}")
            print(f"Answer: {result['model_answer']}")
            print(f"Reasoning:\n{result['reasoning_trace']}")
|