Spaces:
Sleeping
Sleeping
File size: 1,607 Bytes
d103a0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | from typing import Optional, List, Any
from .models import Action, QueryResult
class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment.

    A step's reward is the sum of query-quality bonuses, a late-step
    penalty, a repeated-query penalty and a weighted score for a terminal
    answer; the total is clamped to the range ``[0.0, 1.0]``.

    The weights are class attributes so a subclass can retune them without
    overriding :meth:`calculate`.
    """

    # Bonus for a query whose result carries no error.
    SUCCESS_BONUS = 0.15
    # Bonus when the query references at least one of the task's relevant tables.
    RELEVANT_TABLE_BONUS = 0.10
    # Bonus when the result contains at least one row.
    NON_EMPTY_BONUS = 0.05
    # Bonus when a non-empty result has fewer than ROW_LIMIT rows.
    BOUNDED_RESULT_BONUS = 0.05
    ROW_LIMIT = 1000
    # Steps up to FREE_STEPS are free; every step beyond costs STEP_PENALTY.
    FREE_STEPS = 3
    STEP_PENALTY = 0.02
    # Penalty when the last STUCK_WINDOW history entries are all identical.
    STUCK_PENALTY = 0.10
    STUCK_WINDOW = 3
    # Multiplier applied to ``task.grade(...)`` for a terminal submission.
    ANSWER_WEIGHT = 0.60

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate reward based on action, result, and task.

        Args:
            action: The agent's action; ``sql_query`` and ``submit_answer``
                are read from it.
            result: Result of executing the query, or ``None``; ``error``
                and ``rows`` are read from it.
            task: The current task; must expose ``relevant_tables`` and a
                ``grade(answer)`` method.
            step: Current step number, used for the late-step penalty.
            query_history: Previous queries, used for stuck detection.
            terminal: Whether this is the final step of the episode.

        Returns:
            The reward, clamped to ``[0.0, 1.0]``.
        """
        reward = 0.0

        # Query-quality bonuses only apply to a query that ran without error.
        if action.sql_query and result and not result.error:
            reward += self.SUCCESS_BONUS

            matched = self._count_relevant_tables(
                action.sql_query, task.relevant_tables
            )
            if matched > 0:
                reward += self.RELEVANT_TABLE_BONUS

            row_count = len(result.rows) if result.rows else 0
            if row_count > 0:
                reward += self.NON_EMPTY_BONUS
                # The size bonus is only granted to non-empty results.
                if row_count < self.ROW_LIMIT:
                    reward += self.BOUNDED_RESULT_BONUS

        # Linear penalty for every step taken beyond the free allowance.
        if step > self.FREE_STEPS:
            reward -= self.STEP_PENALTY * (step - self.FREE_STEPS)

        if self._is_stuck(query_history):
            reward -= self.STUCK_PENALTY

        if terminal and action.submit_answer:
            task_score = task.grade(action.submit_answer)
            reward += task_score * self.ANSWER_WEIGHT

        # Penalties may push the sum below 0 and the graded answer above 1.
        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
        """Count the relevant tables that *query* references by name.

        Matching is case-insensitive and requires the table name to appear
        as a whole identifier, so ``user`` is not counted for ``users`` or
        ``user_log``. Empty table names are ignored, and a missing or empty
        ``relevant_tables`` yields 0.
        """
        import re  # local import: the module header does not import ``re``

        if not query or not relevant_tables:
            return 0

        query_lower = query.lower()
        matched = 0
        for table in relevant_tables:
            if not table:
                continue
            # Lookarounds instead of ``\b`` so that names beginning or ending
            # with a non-word character (e.g. quoted names) still match.
            pattern = r"(?<!\w)" + re.escape(table.lower()) + r"(?!\w)"
            if re.search(pattern, query_lower):
                matched += 1
        return matched

    def _is_stuck(self, history: List[str]) -> bool:
        """Return True when the last ``STUCK_WINDOW`` entries are identical."""
        if len(history) < self.STUCK_WINDOW:
            return False
        return len(set(history[-self.STUCK_WINDOW:])) == 1
|