File size: 1,378 Bytes
828c08e
 
1f9850a
828c08e
6ece3a3
828c08e
 
 
1f9850a
 
828c08e
6ece3a3
 
 
 
 
 
 
 
 
 
 
 
828c08e
6ece3a3
 
 
1f9850a
6ece3a3
 
1f9850a
6ece3a3
 
 
 
1f9850a
6ece3a3
1f9850a
6ece3a3
 
828c08e
6ece3a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from typing import Dict, List, Optional

import requests

from gaia_benchmark.run import run_and_submit_all

class GaiaAgent:
    """Answer GAIA benchmark questions by prompting a hosted text-generation endpoint.

    The endpoint URL and bearer token are read from the ``HF_MISTRAL_ENDPOINT``
    and ``HF_TOKEN`` environment variables when the agent is constructed.
    """

    # Seconds to wait for the endpoint before giving up. Without a timeout,
    # ``requests`` waits indefinitely on a stalled connection. Override on the
    # class or an instance if generations legitimately take longer.
    request_timeout: float = 120.0

    def __init__(self):
        """Load endpoint configuration from the environment.

        Raises:
            KeyError: if ``HF_MISTRAL_ENDPOINT`` or ``HF_TOKEN`` is not set.
        """
        self.api_url = os.environ["HF_MISTRAL_ENDPOINT"]
        self.api_key = os.environ["HF_TOKEN"]
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def generate(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_new_tokens: int = 1024,
    ) -> str:
        """POST *prompt* to the endpoint and return the generated text.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            stop: Optional stop sequences forwarded as ``parameters.stop``.
                ``None`` (the default) sends an empty list.
            max_new_tokens: Forwarded as ``parameters.max_new_tokens``.

        Returns:
            The ``generated_text`` field of the response, or ``str()`` of the
            decoded JSON when that field cannot be found.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
            requests.Timeout: if the endpoint does not respond within
                ``request_timeout`` seconds.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                # NOTE(review): greedy decoding seems intended. Some
                # text-generation servers reject temperature == 0 and expect
                # "do_sample": False instead. Confirm against the endpoint.
                "temperature": 0.0,
                "max_new_tokens": max_new_tokens,
                # Copy, so the payload never aliases a caller-owned list.
                "stop": list(stop) if stop else [],
            },
        }
        response = requests.post(
            self.api_url,
            headers=self.headers,
            json=payload,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return self._extract_text(response.json())

    @staticmethod
    def _extract_text(output: object) -> str:
        """Pull ``generated_text`` out of a decoded endpoint response.

        Accepts either a single ``{"generated_text": ...}`` object or a
        non-empty list whose first element is such an object. Anything else,
        such as an empty list, a list of strings or an error object, is
        returned as ``str(output)`` rather than raising.
        """
        if isinstance(output, dict) and "generated_text" in output:
            return output["generated_text"]
        if (
            isinstance(output, list)
            and output
            # Guard against a list of strings, where ``in`` would be a
            # substring test and the subscript below would raise TypeError.
            and isinstance(output[0], dict)
            and "generated_text" in output[0]
        ):
            return output[0]["generated_text"]
        return str(output)

    def answer_question(self, question: Dict) -> str:
        """Build a single-turn prompt from ``question["question"]`` and return
        the model's completion with surrounding whitespace stripped.

        NOTE(review): the prompt calls this a "science question", but GAIA
        questions are general-purpose. Also, if the endpoint echoes the prompt
        in ``generated_text`` (``return_full_text``), the answer will include
        the prompt. Confirm both against the endpoint and the scorer.
        """
        q = question["question"]
        prompt = f"""You are a helpful agent answering a science question.
Question: {q}
Answer:"""
        return self.generate(prompt).strip()

    def run(self):
        """Hand ``answer_question`` to ``run_and_submit_all`` as the per-question callback."""
        run_and_submit_all(self.answer_question)