File size: 2,183 Bytes
828c08e
 
1f9850a
828c08e
92dc164
828c08e
 
 
92dc164
 
 
 
 
 
 
 
 
 
 
828c08e
6ece3a3
 
 
 
 
92dc164
 
 
6ece3a3
 
 
92dc164
 
 
 
828c08e
92dc164
 
 
 
 
 
 
 
6ece3a3
92dc164
 
6ece3a3
 
1f9850a
6ece3a3
92dc164
 
6ece3a3
 
1f9850a
92dc164
 
6ece3a3
1f9850a
6ece3a3
92dc164
6ece3a3
828c08e
6ece3a3
92dc164
57a0245
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from typing import Dict, List, Optional

import requests

from run import run_and_submit_all  # Adjust path if needed

class GaiaAgent:
    """Answer questions by prompting a remote text-generation HTTP endpoint.

    Configuration is read from environment variables when the agent is built:

    * ``HF_MISTRAL_ENDPOINT`` - URL that prompts are POSTed to.
    * ``HF_TOKEN`` - sent as a ``Bearer`` token in the ``Authorization`` header.
    * ``LLM_MODEL_ID`` - only logged here; it is not sent with requests.
    """

    def __init__(self, timeout: float = 120.0):
        """Read endpoint configuration from the environment and build headers.

        Args:
            timeout: Seconds ``requests`` waits for the endpoint before the call
                is abandoned. Without a timeout a stalled endpoint would block
                forever.

        Raises:
            AssertionError: If any required environment variable is unset or
                empty.
        """
        self.api_url = os.environ.get("HF_MISTRAL_ENDPOINT")
        self.api_key = os.environ.get("HF_TOKEN")
        self.model_id = os.environ.get("LLM_MODEL_ID")
        self.timeout = timeout

        # Explicit raises rather than `assert` statements: asserts are removed
        # under `python -O`, which would let a misconfigured agent start up.
        # AssertionError is kept so callers see the same exception type as before.
        if not self.api_url:
            raise AssertionError("❌ HF_MISTRAL_ENDPOINT is missing!")
        if not self.api_key:
            raise AssertionError("❌ HF_TOKEN is missing!")
        if not self.model_id:
            raise AssertionError("❌ LLM_MODEL_ID is missing!")

        print(f"βœ… [INIT] Model ID: {self.model_id}")
        print(f"βœ… [INIT] Endpoint: {self.api_url}")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """POST *prompt* to the configured endpoint and return the generated text.

        Args:
            prompt: Text sent as the ``inputs`` field of the JSON request body.
            stop: Stop sequences forwarded as ``parameters.stop``; ``None``
                (the default) sends an empty list.

        Returns:
            The ``generated_text`` field of the response (taken from the first
            element when the response is a list), ``str(output)`` for any other
            JSON shape, or ``"ERROR: Model call failed"`` when the request fails
            or the response body is not valid JSON.
        """
        print("🧠 [GENERATE] Prompt sent to model:")
        print(prompt)

        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": 0.0,
                "max_new_tokens": 1024,
                # Copy so the request body never aliases the caller's list.
                "stop": list(stop) if stop else [],
            },
        }

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            # Connection errors, timeouts and non-2xx statuses all land here.
            print(f"❌ [ERROR] Request failed: {e}")
            return "ERROR: Model call failed"

        try:
            output = response.json()
        except ValueError as e:
            # A non-JSON body (e.g. an HTML error page) must not crash the run.
            print(f"[ERROR] Response was not valid JSON: {e}")
            return "ERROR: Model call failed"

        print(f"βœ… [RESPONSE] Raw output: {output}")
        return self._extract_text(output)

    @staticmethod
    def _extract_text(output: object) -> str:
        """Pull ``generated_text`` out of a decoded JSON response.

        Accepts either a dict with a ``generated_text`` key or a non-empty list
        whose first element is such a dict. Any other shape (including an empty
        list or a list of non-dicts) is returned as ``str(output)``.
        """
        if isinstance(output, dict) and "generated_text" in output:
            return output["generated_text"]
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], dict)
            and "generated_text" in output[0]
        ):
            return output[0]["generated_text"]
        return str(output)

    def answer_question(self, question: Dict) -> str:
        """Answer one question and return the model's whitespace-stripped reply.

        Args:
            question: Mapping that must contain a ``"question"`` key. An
                optional ``"id"`` key is used only for logging.

        Returns:
            The stripped text returned by :meth:`generate`.
        """
        q = question["question"]
        print(f"πŸ“Œ [QUESTION] ID: {question.get('id', 'N/A')} - {q}")

        prompt = f"""You are a helpful agent answering a science question.
Question: {q}
Answer:"""

        return self.generate(prompt).strip()

    def run(self):
        """Run the submission flow with :meth:`answer_question` as the answerer.

        Returns whatever ``run_and_submit_all`` returns.
        """
        print("πŸš€ [RUN] Starting submission...")
        return run_and_submit_all(self.answer_question)