Nitien committed (verified)
Commit e93fa76 · Parent(s): 9a976c0

Update gaia_agent.py

Files changed (1): gaia_agent.py (+47 -12)
gaia_agent.py CHANGED
@@ -14,6 +14,7 @@ from langchain_core.messages import HumanMessage
 from langchain_openrouter import ChatOpenRouter
 from langgraph.graph import StateGraph, START, END
 from langgraph.checkpoint.memory import MemorySaver
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from typing import TypedDict
 
 from customtools import (
@@ -26,10 +27,14 @@ from customtools import (
     subtraction_tool,
     multiplication_tool,
     transcribe_audio,
+    modulus_tool,
 )
 from config import (
     OPENROUTER_API_KEY,
     LLM_MODEL,
+    NVIDIA,
+    NVIDIA_API_KEY,
+    NVIDIA_MODEL,
     LLM_TEMPERATURE,
     OUTPUT_FILE,
     FINAL_ANSWER_MAX_LENGTH,
@@ -48,12 +53,24 @@ memory = MemorySaver()
 def connect_models():
     """Initialize and return the LLM instance."""
     try:
-        print(f"Connecting to LLM: {LLM_MODEL}")
-        llm = ChatOpenRouter(
-            model=LLM_MODEL,
-            temperature=LLM_TEMPERATURE,
-            api_key=OPENROUTER_API_KEY,
-        )
+        global llm
+        if NVIDIA:
+            llm = ChatNVIDIA(
+                model=NVIDIA_MODEL,
+                api_key=NVIDIA_API_KEY,
+                temperature=1,
+                top_p=1,
+            )
+        else:
+            print(f"Connecting to LLM: {LLM_MODEL}")
+            llm = ChatOpenRouter(
+                model=LLM_MODEL,
+                temperature=LLM_TEMPERATURE,
+                api_key=OPENROUTER_API_KEY,
+            )
         return llm
     except Exception as e:
         print(f"Error initializing LLM: {e}")
@@ -71,6 +88,7 @@ TOOLS = {
     "extract_text_from_image": extract_text_from_image,
     "wikisearch": wikisearch,
     "transcribe_audio": transcribe_audio,
+    "modulus_tool": modulus_tool,
 }
 
@@ -101,6 +119,7 @@ class Step(BaseModel):
         "addition_tool",
         "subtraction_tool",
         "multiplication_tool",
+        "modulus_tool",
         "none",
     ]
     tool_input: str
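
modulus_tool is now registered in three places (the customtools import, the TOOLS map, and the Step tool literal), but customtools.py itself is outside this diff. Assuming it follows the same pattern as the other arithmetic tools, and given that Step carries tool_input as a single string, a sketch could look like:

    # Hypothetical sketch of modulus_tool in customtools.py -- the real
    # implementation is not shown here. The input-parsing scheme is an
    # assumption based on Step.tool_input being a plain string.
    def modulus_tool(tool_input: str) -> str:
        """Return a % b for an input like "17, 5"."""
        a, b = (float(part.strip()) for part in tool_input.split(","))
        return str(a % b)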
@@ -131,6 +150,7 @@ def planner_node(state: AgentState):
 def execute_step_node(state: AgentState):
     """Execute step node: prepares tool invocation."""
     step = state["plan"][state["current_step"]]
+    print(f"Current Step:{step}")
     tool_name = step.get("tool", "none")
 
     print(f"Executing step {state['current_step'] + 1}/{len(state['plan'])}: {tool_name}")
@@ -280,10 +300,6 @@ def format_reasoning_trace(intermediate_results: List[Dict[str, Any]]) -> str:
     return "\n".join(trace_lines)
 
 
-##################################################################################################################
-# For Local env.
-##################################################################################################################
-
 def process_questions(questions_file: str = None, questions_list: List[str] = None) -> str:
     """
     Process multiple questions and save results to a file
@@ -367,7 +383,26 @@ def process_questions(questions_file: str = None, questions_list: List[str] = None) -> str:
 
 
-#if __name__ == "__main__":
-
-#    global llm
+if __name__ == "__main__":
+    global llm
+    # Example questions to process
+    questions = [
+        """Task ID: 52e8ce1c-09bd-4537-8e2d-67d1648779b9 ; Question: The attached .csv file shows precipitation amounts, in inches, for the five boroughs of New York City in a certain year. How many inches of precipitation did the city receive in total for that year? Don’t use commas if the number has four or more digits. ; file_name: /home/nitin/.cache/huggingface/hub/datasets--gaia-benchmark--GAIA/snapshots/682dd723ee1e1697e00360edccf2366dc8418dd9/2023/test/52e8ce1c-09bd-4537-8e2d-67d1648779b9.csv
+        """
+        #"What is the square of the population of France in millions?",
+        #"What is 50 plus 75?"
+    ]
+
+    # Process all questions
+    output_file = process_questions(questions_list=questions)
+
+    # Print the results
+    print("\nResults from file:")
+    with open(output_file, 'r') as f:
+        for line in f:
+            result = json.loads(line)
+            print(f"\nTask ID: {result['task_id']}")
+            print(f"Answer: {result['model_answer']}")
+            print(f"Reasoning:\n{result['reasoning_trace']}")