File size: 1,607 Bytes
d103a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import re
from typing import Optional, List, Any

from .models import Action, QueryResult


class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment.

    The per-step reward is the sum of several shaped components, clamped to
    ``[0.0, 1.0]``:

    * query bonuses, when a SQL query was run and produced a result:
      executed without error, references a task-relevant table, returned at
      least one row, and returned fewer than ``MAX_ROWS`` rows;
    * a linear penalty for every step beyond ``FREE_STEPS``;
    * a penalty when the last ``STUCK_WINDOW`` queries are identical;
    * on a terminal submission, ``task.grade(answer)`` scaled by
      ``ANSWER_WEIGHT``.

    The weights are class attributes so they can be tuned by subclassing
    without touching the scoring logic.
    """

    # Query-quality bonuses.
    SUCCESS_BONUS = 0.15
    RELEVANT_TABLE_BONUS = 0.10
    NON_EMPTY_BONUS = 0.05
    BOUNDED_RESULT_BONUS = 0.05
    # Exclusive upper bound on the row count that earns BOUNDED_RESULT_BONUS.
    MAX_ROWS = 1000

    # Penalties.
    FREE_STEPS = 3  # steps up to and including this one are not penalized
    STEP_PENALTY = 0.02  # per step beyond FREE_STEPS
    STUCK_PENALTY = 0.10
    STUCK_WINDOW = 3  # number of trailing identical queries that count as stuck

    # Weight applied to the task's grade of a submitted answer.
    ANSWER_WEIGHT = 0.60

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate reward based on action, result, and task.

        Args:
            action: The agent's action; ``sql_query`` and ``submit_answer``
                are read.
            result: Outcome of executing ``action.sql_query``, or ``None``
                if nothing was executed. ``error`` and ``rows`` are read.
            task: The current task; must expose ``relevant_tables`` and a
                ``grade(answer)`` method.
            step: Current step number; steps beyond ``FREE_STEPS`` are
                penalized linearly.
            query_history: Queries issued so far, used for stuck detection.
            terminal: Whether this step ends the episode; a submitted answer
                is only graded on terminal steps.

        Returns:
            The total reward, clamped to ``[0.0, 1.0]``.
        """
        reward = 0.0

        if action.sql_query and result is not None:
            if not result.error:
                reward += self.SUCCESS_BONUS

            # NOTE(review): this bonus is granted even when the query errored,
            # so broken SQL that merely names the right tables is still
            # rewarded — confirm that is intended.
            relevant = self._count_relevant_tables(
                action.sql_query, task.relevant_tables
            )
            if relevant > 0:
                reward += self.RELEVANT_TABLE_BONUS

            if result.rows:
                reward += self.NON_EMPTY_BONUS
                # Favour focused queries over dumping huge result sets.
                if len(result.rows) < self.MAX_ROWS:
                    reward += self.BOUNDED_RESULT_BONUS

        if step > self.FREE_STEPS:
            reward -= self.STEP_PENALTY * (step - self.FREE_STEPS)

        if self._is_stuck(query_history):
            reward -= self.STUCK_PENALTY

        if terminal and action.submit_answer:
            task_score = task.grade(action.submit_answer)
            reward += task_score * self.ANSWER_WEIGHT

        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(
        self, query: str, relevant_tables: Optional[List[str]]
    ) -> int:
        """Return how many of *relevant_tables* are referenced in *query*.

        Matching is case-insensitive. A table name must stand alone as an
        identifier (not preceded or followed by a word character), so
        ``orders`` does not match ``orders_archive`` and ``user`` does not
        match ``users``. Empty or missing table names are ignored.
        """
        if not relevant_tables:
            return 0
        return sum(
            1
            for table in relevant_tables
            if table
            and re.search(
                rf"(?<!\w){re.escape(table)}(?!\w)", query, re.IGNORECASE
            )
        )

    def _is_stuck(self, history: List[str]) -> bool:
        """Return True if the last ``STUCK_WINDOW`` queries are all identical."""
        window = self.STUCK_WINDOW
        if len(history) < window:
            return False
        return len(set(history[-window:])) == 1