import re
from typing import Any, List, Optional

from .models import Action, QueryResult


class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment.

    The per-step reward is the sum of:

    * shaping bonuses for a query that executed without error, referenced at
      least one of the task's relevant tables, returned rows, and returned
      fewer than ``ROW_LIMIT`` rows;
    * a linear penalty for every step past ``FREE_STEPS``;
    * a penalty when the last ``STUCK_WINDOW`` queries were identical
      (ignoring whitespace differences);
    * on a terminal step with a submitted answer, ``ANSWER_WEIGHT`` times the
      value returned by ``task.grade`` for that answer.

    The total is clamped to ``[0.0, 1.0]``. Weights and thresholds are class
    attributes so they can be tuned (e.g. in a subclass) without overriding
    ``calculate``.
    """

    # Shaping bonuses for a successfully executed query.
    SUCCESS_REWARD = 0.15
    RELEVANT_TABLE_REWARD = 0.10
    NON_EMPTY_REWARD = 0.05
    BOUNDED_RESULT_REWARD = 0.05
    # Results with at least this many rows do not earn BOUNDED_RESULT_REWARD.
    ROW_LIMIT = 1000

    # Steps up to FREE_STEPS are free; each later step costs STEP_PENALTY.
    FREE_STEPS = 3
    STEP_PENALTY = 0.02

    # Penalty applied when the last STUCK_WINDOW queries are all the same.
    STUCK_WINDOW = 3
    STUCK_PENALTY = 0.10

    # Weight applied to task.grade(...) on a terminal submission.
    ANSWER_WEIGHT = 0.60

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate reward based on action, result, and task.

        Args:
            action: The agent's action; ``sql_query`` and ``submit_answer``
                are read.
            result: Execution result of ``action.sql_query``, or ``None`` if
                no query was executed. ``error`` and ``rows`` are read.
            task: Task object; must expose ``relevant_tables`` (iterable of
                table names) and ``grade(answer)`` returning a number.
            step: Current step index in the episode.
            query_history: Queries issued so far, used for stuck detection.
            terminal: Whether this step ends the episode.

        Returns:
            The reward, clamped to ``[0.0, 1.0]``.
        """
        reward = 0.0

        if action.sql_query and result is not None:
            reward += self._query_reward(action.sql_query, result, task)

        if step > self.FREE_STEPS:
            reward -= self.STEP_PENALTY * (step - self.FREE_STEPS)

        if self._is_stuck(query_history):
            reward -= self.STUCK_PENALTY

        # The answer is only graded when the episode ends with a submission.
        if terminal and action.submit_answer:
            task_score = task.grade(action.submit_answer)
            reward += task_score * self.ANSWER_WEIGHT

        return max(0.0, min(1.0, reward))

    def _query_reward(self, query: str, result: QueryResult, task: Any) -> float:
        """Return the shaping bonus for one executed query (0.0 if it errored)."""
        if result.error:
            return 0.0

        bonus = self.SUCCESS_REWARD

        relevant_tables = task.relevant_tables or []
        if self._count_relevant_tables(query, relevant_tables) > 0:
            bonus += self.RELEVANT_TABLE_REWARD

        rows = result.rows
        if rows:
            bonus += self.NON_EMPTY_REWARD
            # Reward reasonably sized result sets over dumping whole tables.
            if len(rows) < self.ROW_LIMIT:
                bonus += self.BOUNDED_RESULT_REWARD

        return bonus

    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
        """Count how many of ``relevant_tables`` are referenced in ``query``.

        Matching is case-insensitive and requires each name to appear as a
        whole identifier. As a result, ``order`` does not match ``orders`` or
        ``order_items``, which plain substring matching would. Qualified
        references such as ``public.orders`` still match ``orders``. Empty
        names are ignored.
        """
        query_lower = query.lower()
        return sum(
            1
            for table in relevant_tables
            if table
            and re.search(
                r"(?<!\w)" + re.escape(table.lower()) + r"(?!\w)", query_lower
            )
        )

    def _is_stuck(self, history: List[str]) -> bool:
        """Return True if the last ``STUCK_WINDOW`` queries are identical.

        Queries are compared after collapsing runs of whitespace. Re-issuing
        the same query with different spacing or line breaks therefore still
        counts as a repeat.
        """
        window = self.STUCK_WINDOW
        if len(history) < window:
            return False
        recent = {" ".join(query.split()) for query in history[-window:]}
        return len(recent) == 1