File size: 1,487 Bytes
828c08e
6ece3a3
 
828c08e
 
6ece3a3
 
828c08e
 
 
6ece3a3
 
828c08e
6ece3a3
 
 
 
 
 
 
 
 
 
 
 
828c08e
6ece3a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828c08e
6ece3a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import json
import os
from typing import Dict, List, Optional

import requests

from gaia_benchmark.questions import load_questions
from gaia_benchmark.run import run_and_submit_all

class GaiaAgent:
    """Answer GAIA benchmark questions by prompting a hosted text-generation endpoint.

    The endpoint URL and bearer token are read from the environment when the
    agent is constructed; each question is sent as a single POST request.
    """

    #: Default number of seconds to wait for the endpoint before giving up.
    DEFAULT_TIMEOUT: float = 120.0

    def __init__(self, timeout: Optional[float] = DEFAULT_TIMEOUT):
        """Read endpoint configuration from the environment.

        Args:
            timeout: Seconds ``requests`` waits for the endpoint before raising
                a timeout error. ``None`` waits indefinitely.

        Raises:
            KeyError: If ``HF_MISTRAL_ENDPOINT`` or ``HF_TOKEN`` is not set.
        """
        self.api_url = os.environ["HF_MISTRAL_ENDPOINT"]  # Your Mistral endpoint
        self.api_key = os.environ["HF_TOKEN"]             # Hugging Face token
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_payload(prompt: str, stop: Optional[List[str]]) -> Dict:
        """Return the JSON request body for one generation call."""
        return {
            "inputs": prompt,
            "parameters": {
                # NOTE(review): intended as greedy decoding; some text-generation
                # servers reject temperature <= 0 — confirm against the endpoint.
                "temperature": 0.0,
                "max_new_tokens": 1024,
                # Copy so the caller's list is never aliased into the request body.
                "stop": list(stop) if stop else [],
            },
        }

    @staticmethod
    def _extract_text(output: object) -> str:
        """Pull ``generated_text`` out of a decoded endpoint response.

        Accepts either a dict with a ``generated_text`` key or a non-empty list
        whose first element is such a dict. Any other shape (including an empty
        list or an error payload) is returned stringified instead of raising.
        """
        if isinstance(output, dict) and "generated_text" in output:
            return output["generated_text"]
        if isinstance(output, list) and output:
            first = output[0]
            if isinstance(first, dict) and "generated_text" in first:
                return first["generated_text"]
        return str(output)

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send *prompt* to the endpoint and return the generated text.

        Args:
            prompt: Text passed as the request's ``inputs`` field.
            stop: Optional stop sequences forwarded to the endpoint; ``None``
                sends an empty list.

        Returns:
            The ``generated_text`` field of the response, or the stringified
            response when it has an unexpected shape.
            NOTE(review): depending on the endpoint, this text may echo the
            prompt (HF ``return_full_text``) — confirm.

        Raises:
            requests.HTTPError: If the endpoint returns an error status.
            requests.Timeout: If no response arrives within ``self.timeout``.
        """
        payload = self._build_payload(prompt, stop)
        response = requests.post(
            self.api_url,
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._extract_text(response.json())

    def answer_question(self, question: Dict) -> str:
        """Return the model's whitespace-stripped answer to one question.

        Args:
            question: Mapping that must contain a ``"question"`` key holding
                the question text.
        """
        question_text = question["question"]
        prompt = f"""You are a helpful agent answering a science question.
Question: {question_text}
Answer:"""
        return self.generate(prompt).strip()

    def run(self):
        """Hand ``answer_question`` to ``run_and_submit_all`` as its callback.

        Fetching, answering and submitting are done by ``run_and_submit_all``
        (defined in ``gaia_benchmark.run``, not shown here).
        """
        run_and_submit_all(self.answer_question)