# Source: sql_data_analyst/env/reward.py
# Commit d103a0f: "SQL Data Analyst OpenEnv - Initial commit" (YashashMathur)
import re
from typing import Optional, List, Any

from .models import Action, QueryResult
class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment.

    A step's reward is the sum of small bonuses for a successful query,
    penalties for long or repetitive episodes, and a weighted share of the
    task's grade for a submitted final answer.  The sum is clamped to the
    range ``[0.0, 1.0]`` before it is returned.
    """

    # Bonuses granted for a query that executed without error.
    QUERY_SUCCESS_BONUS = 0.15
    RELEVANT_TABLE_BONUS = 0.10
    NON_EMPTY_RESULT_BONUS = 0.05
    BOUNDED_RESULT_BONUS = 0.05
    # A non-empty result earns the "bounded" bonus only below this row count.
    BOUNDED_RESULT_ROW_LIMIT = 1000

    # Steps beyond FREE_STEPS each cost EXTRA_STEP_PENALTY.
    FREE_STEPS = 3
    EXTRA_STEP_PENALTY = 0.02

    # Repeating the same query STUCK_WINDOW times in a row costs STUCK_PENALTY.
    STUCK_WINDOW = 3
    STUCK_PENALTY = 0.10

    # Share of the task's grade that is added for a terminal answer.
    ANSWER_WEIGHT = 0.60

    def calculate(
        self,
        action: "Action",
        result: "Optional[QueryResult]",
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate reward based on action, result, and task.

        Args:
            action: The agent's action; ``sql_query`` and ``submit_answer``
                are read from it.
            result: Outcome of executing ``action.sql_query`` (``error`` and
                ``rows`` are read), or ``None`` when nothing was executed.
            task: Object providing ``relevant_tables`` and
                ``grade(answer)``.
            step: Current step number; steps past ``FREE_STEPS`` are
                penalised linearly.
            query_history: Queries issued so far, oldest first.
            terminal: Whether this is the final step of the episode.

        Returns:
            The reward for this step, clamped to ``[0.0, 1.0]``.
        """
        reward = 0.0

        # NOTE(review): the nesting below was reconstructed from a copy of
        # this file that had lost its indentation; every query bonus is
        # assumed to require an error-free result -- confirm against the
        # canonical source.
        if action.sql_query and result and not result.error:
            reward += self.QUERY_SUCCESS_BONUS

            touched = self._count_relevant_tables(
                action.sql_query, task.relevant_tables
            )
            if touched > 0:
                reward += self.RELEVANT_TABLE_BONUS

            rows = result.rows
            if rows:
                reward += self.NON_EMPTY_RESULT_BONUS
                if len(rows) < self.BOUNDED_RESULT_ROW_LIMIT:
                    reward += self.BOUNDED_RESULT_BONUS

        if step > self.FREE_STEPS:
            reward -= self.EXTRA_STEP_PENALTY * (step - self.FREE_STEPS)

        if self._is_stuck(query_history):
            reward -= self.STUCK_PENALTY

        if terminal and action.submit_answer:
            reward += task.grade(action.submit_answer) * self.ANSWER_WEIGHT

        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(
        self, query: str, relevant_tables: Optional[List[str]]
    ) -> int:
        """Return how many of ``relevant_tables`` are named in ``query``.

        Matching is case-insensitive and requires the table name to appear
        as a whole identifier, so ``user`` is not counted for ``user_id`` or
        ``users``.  Empty names and a missing table list count as zero.
        """
        if not query or not relevant_tables:
            return 0

        query_lower = query.lower()

        def is_mentioned(table: str) -> bool:
            name = table.lower()
            if not name:
                return False
            # Lookarounds instead of \b so that names which begin or end
            # with a non-word character still match as a unit.
            pattern = r"(?<!\w)" + re.escape(name) + r"(?!\w)"
            return re.search(pattern, query_lower) is not None

        return sum(1 for table in relevant_tables if is_mentioned(table))

    def _is_stuck(self, history: List[str]) -> bool:
        """Return True when the last ``STUCK_WINDOW`` queries are identical."""
        if len(history) < self.STUCK_WINDOW:
            return False
        return len(set(history[-self.STUCK_WINDOW:])) == 1