File size: 2,329 Bytes
828c08e
 
1f9850a
828c08e
92dc164
828c08e
 
 
92dc164
 
 
 
 
 
 
 
 
 
 
828c08e
6ece3a3
 
 
 
 
92dc164
 
 
6ece3a3
 
 
92dc164
 
 
 
828c08e
92dc164
 
 
 
 
 
 
 
6ece3a3
92dc164
 
6ece3a3
 
1f9850a
6ece3a3
92dc164
 
6ece3a3
 
4d888ec
 
 
 
 
3fb74ad
4d888ec
 
3fb74ad
828c08e
6ece3a3
92dc164
57a0245
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from typing import Dict, List, Optional

import requests

from run import run_and_submit_all  # Adjust path if needed

class GaiaAgent:
    """Question-answering agent backed by a Hugging Face text-generation endpoint.

    Configuration is read from environment variables at construction time:

    * ``HF_MISTRAL_ENDPOINT`` -- URL the generation request is POSTed to.
    * ``HF_TOKEN`` -- bearer token sent in the ``Authorization`` header.
    * ``LLM_MODEL_ID`` -- model identifier; required and logged, but not
      included in the request payload.
    """

    # Seconds to wait for the endpoint before treating the call as failed.
    # Without a timeout, ``requests`` can block indefinitely on a stalled server.
    DEFAULT_TIMEOUT = 120.0
    # Upper bound on the number of tokens the model may generate per call.
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self):
        """Load endpoint configuration from the environment.

        Raises:
            RuntimeError: if any of the three required variables is unset or
                empty.
        """
        self.api_url = os.environ.get("HF_MISTRAL_ENDPOINT")
        self.api_key = os.environ.get("HF_TOKEN")
        self.model_id = os.environ.get("LLM_MODEL_ID")

        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would let a misconfigured agent start.
        for value, message in (
            (self.api_url, "❌ HF_MISTRAL_ENDPOINT is missing!"),
            (self.api_key, "❌ HF_TOKEN is missing!"),
            (self.model_id, "❌ LLM_MODEL_ID is missing!"),
        ):
            if not value:
                raise RuntimeError(message)

        print(f"✅ [INIT] Model ID: {self.model_id}")
        print(f"✅ [INIT] Endpoint: {self.api_url}")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_payload(
        prompt: str, stop: Optional[List[str]], max_new_tokens: int
    ) -> Dict:
        """Return the JSON request body for one text-generation call."""
        return {
            "inputs": prompt,
            "parameters": {
                # Greedy (deterministic) decoding. A temperature of 0.0 is
                # rejected by text-generation-inference, which requires a
                # strictly positive temperature, so sampling is disabled instead.
                "do_sample": False,
                "max_new_tokens": max_new_tokens,
                # Copy so a caller-supplied list is never shared with the payload.
                "stop": list(stop) if stop else [],
                # Only the continuation is wanted; don't echo the prompt back.
                "return_full_text": False,
            },
        }

    @staticmethod
    def _extract_generated_text(output) -> str:
        """Pull ``generated_text`` out of a decoded endpoint response.

        Accepts either ``{"generated_text": ...}`` or
        ``[{"generated_text": ...}, ...]``. Any other shape, including an
        empty list or a list whose first item is not a dict, is returned as
        ``str(output)`` so the caller always gets a string.
        """
        if isinstance(output, dict) and "generated_text" in output:
            return output["generated_text"]
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], dict)
            and "generated_text" in output[0]
        ):
            return output[0]["generated_text"]
        return str(output)

    def generate(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> str:
        """Send *prompt* to the endpoint and return the generated text.

        Args:
            prompt: Full prompt text sent as ``inputs``.
            stop: Optional stop sequences; ``None`` means no stop sequences.
            max_new_tokens: Generation length limit passed to the endpoint.
            timeout: Seconds to wait for the HTTP response.

        Returns:
            The generated text. On an HTTP/network failure or a non-JSON
            response, returns the sentinel string "ERROR: Model call failed"
            instead of raising, so one failed call does not abort a batch.
        """
        print("🧠 [GENERATE] Prompt sent to model:")
        print(prompt)

        payload = self._build_payload(prompt, stop, max_new_tokens)

        try:
            response = requests.post(
                self.api_url, headers=self.headers, json=payload, timeout=timeout
            )
            response.raise_for_status()
            # Decoding is inside the try: an HTML/plain-text error body must
            # take the same fallback path as a transport error.
            output = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"❌ [ERROR] Request failed: {e}")
            return "ERROR: Model call failed"

        print(f"✅ [RESPONSE] Raw output: {output}")
        return self._extract_generated_text(output)

    def answer_question(self, question: Dict) -> str:
        """Answer one question record and return the stripped model output.

        The question text is taken from the first non-empty of the keys
        ``"question"``, ``"Question"`` or ``"input"``.

        Raises:
            ValueError: if none of those keys holds question text.
        """
        q = question.get("question") or question.get("Question") or question.get("input")
        if not q:
            raise ValueError(f"No question text found in: {question}")

        # Built from explicit lines so the method's indentation does not leak
        # into the prompt, as it would with an indented triple-quoted string.
        prompt = (
            "You are a helpful agent answering a science question.\n"
            f"Question: {q}\n"
            "Answer:"
        )
        return self.generate(prompt).strip()

    def run(self):
        """Answer and submit all questions via ``run_and_submit_all``."""
        print("🚀 [RUN] Starting submission...")
        return run_and_submit_all(self.answer_question)