Spaces:
Sleeping
Sleeping
import re
from typing import Optional, List, Any

from .models import Action, QueryResult
class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment.

    The reward for one step is the sum of the components below, clamped to
    ``[0.0, 1.0]``:

    * ``SUCCESS_BONUS`` when the action's SQL ran without error, plus
      - ``RELEVANT_TABLE_BONUS`` if the SQL names at least one of the task's
        relevant tables (whole-identifier, case-insensitive match),
      - ``NON_EMPTY_BONUS`` if it returned at least one row,
      - ``REASONABLE_SIZE_BONUS`` if it returned rows, but fewer than
        ``MAX_REASONABLE_ROWS``;
    * ``-STEP_PENALTY`` for every step beyond ``FREE_STEPS``;
    * ``-STUCK_PENALTY`` when the last ``STUCK_WINDOW`` history entries are
      identical;
    * ``ANSWER_WEIGHT * task.grade(answer)`` on a terminal step that submits
      an answer.
    """

    # Reward weights and thresholds. They are class attributes so they can be
    # tuned (e.g. in a subclass) without touching ``calculate``. The defaults
    # reproduce the original hard-coded values.
    SUCCESS_BONUS = 0.15
    RELEVANT_TABLE_BONUS = 0.10
    NON_EMPTY_BONUS = 0.05
    REASONABLE_SIZE_BONUS = 0.05
    MAX_REASONABLE_ROWS = 1000
    FREE_STEPS = 3
    STEP_PENALTY = 0.02
    STUCK_PENALTY = 0.10
    STUCK_WINDOW = 3
    ANSWER_WEIGHT = 0.60

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate the reward for one environment step.

        Args:
            action: The agent's action. ``sql_query`` and ``submit_answer``
                are read.
            result: Result of executing ``action.sql_query``, or ``None`` if
                nothing was executed. ``error`` and ``rows`` are read.
            task: The current task. Must expose ``relevant_tables`` (an
                iterable of table names) and ``grade(answer)``, which returns
                a number (presumably in ``[0, 1]`` — TODO confirm).
            step: Current step number. Each step after ``FREE_STEPS`` costs
                ``STEP_PENALTY``.
            query_history: Previously issued queries. Only the last
                ``STUCK_WINDOW`` entries are inspected.
            terminal: Whether this step ends the episode. The answer is only
                graded when this is true.

        Returns:
            The summed reward, clamped to ``[0.0, 1.0]``.
        """
        reward = 0.0

        # Shaping bonuses are only granted for a query that actually executed
        # successfully.
        if action.sql_query and result is not None and not result.error:
            reward += self.SUCCESS_BONUS
            if self._count_relevant_tables(action.sql_query, task.relevant_tables) > 0:
                reward += self.RELEVANT_TABLE_BONUS
            if result.rows:
                reward += self.NON_EMPTY_BONUS
                # Small, focused result sets earn a little extra.
                if len(result.rows) < self.MAX_REASONABLE_ROWS:
                    reward += self.REASONABLE_SIZE_BONUS

        # NOTE(review): the upstream source had its indentation stripped. The
        # step/stuck/answer terms are assumed to apply on every step, not only
        # when a query ran — confirm against the original repository.
        if step > self.FREE_STEPS:
            reward -= self.STEP_PENALTY * (step - self.FREE_STEPS)
        if self._is_stuck(query_history):
            reward -= self.STUCK_PENALTY
        if terminal and action.submit_answer:
            task_score = task.grade(action.submit_answer)
            reward += task_score * self.ANSWER_WEIGHT

        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
        """Count how many of ``relevant_tables`` occur in ``query`` as whole identifiers.

        Matching is case-insensitive. A name only counts when it is not
        immediately preceded or followed by a word character, so table
        ``sales`` is not credited for a column ``sales_total``, and table
        ``user`` is not credited for ``users``.

        Schema-qualified (``main.sales``) and quoted (``"sales"``) references
        still match. Empty names are ignored, and a ``None`` table list counts
        as empty.

        Known limitation: this is not a SQL parser, so a table whose name is
        also a SQL keyword (e.g. ``order``) can still match that keyword.
        """
        query_lower = query.lower()
        count = 0
        for table in relevant_tables or ():
            if not table:
                continue
            # Lookarounds instead of \b so names that start or end with
            # non-word characters are still handled correctly.
            pattern = r"(?<!\w)" + re.escape(table.lower()) + r"(?!\w)"
            if re.search(pattern, query_lower):
                count += 1
        return count

    def _is_stuck(self, history: List[str]) -> bool:
        """Return True when the last ``STUCK_WINDOW`` entries of ``history`` are all identical.

        Histories shorter than the window are never considered stuck.
        """
        window = self.STUCK_WINDOW
        if len(history) < window:
            return False
        return len(set(history[-window:])) == 1