YashashMathur committed on
Commit
611be05
·
verified ·
1 Parent(s): 5f88c3f

Fix grader scores strictly between 0 and 1

Browse files
Files changed (1) hide show
  1. environment/graders.py +228 -0
environment/graders.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # environment/graders.py
2
+ # Deterministic grading system for SQL Data Analyst environment
3
+ # Implements type-agnostic normalization and SQL evaluation
4
+
5
+ from typing import Any, Tuple, Optional
6
+ import re
7
+
8
+
9
def normalize_value(value: Any) -> str:
    """Return a canonical string form of *value* for loose comparison.

    Normalization rules:
    - None becomes the empty string.
    - Everything else is stringified, stripped, lowercased, and has
      internal whitespace runs collapsed to single spaces.
    - Anything parseable as a float is rendered rounded to 2 decimal
      places so e.g. "3.14159" and 3.14 compare equal.

    Args:
        value: Any value to normalize.

    Returns:
        str: Normalized string representation.
    """
    if value is None:
        return ""

    # Stringify, trim, lowercase, and collapse internal whitespace.
    text = re.sub(r"\s+", " ", str(value).strip().lower())

    try:
        as_number = float(text)
    except (ValueError, TypeError):
        # Not numeric: the cleaned-up text is the canonical form.
        return text

    # Numeric-looking input: round for tolerant comparison.
    return str(round(as_number, 2))
43
+
44
+
45
def extract_numeric(value: str) -> Optional[float]:
    """Parse a float out of *value*, tolerating "$" and "," formatting.

    Args:
        value: String that may contain a number (e.g. "$1,234.50").

    Returns:
        Optional[float]: Extracted number, or None if no number parses.
    """
    # Strip surrounding whitespace and common currency formatting.
    cleaned = str(value).strip().replace("$", "").replace(",", "")

    try:
        return float(cleaned)
    except (ValueError, TypeError):
        return None
62
+
63
+
64
def compare_values(submitted: Any, ground_truth: Any) -> Tuple[bool, float]:
    """
    Compare submitted answer to ground truth.

    Args:
        submitted: The agent's submitted answer
        ground_truth: The expected correct answer

    Returns:
        Tuple[bool, float]: (is_correct, score)
        - is_correct: True if answer matches
        - score: Value strictly between 0.0 and 1.0
          (0.99 correct, 0.05 near-miss numeric, 0.01 otherwise)
    """
    # Normalize both values for type-agnostic comparison.
    norm_submitted = normalize_value(submitted)
    norm_truth = normalize_value(ground_truth)

    # Direct string comparison after normalization.
    if norm_submitted == norm_truth:
        return True, 0.99

    # Try numeric comparison for numeric ground truths.
    if isinstance(ground_truth, (int, float)):
        submitted_num = extract_numeric(submitted)
        if submitted_num is not None:
            truth_num = float(ground_truth)
            # Allow small floating point tolerance.
            if abs(submitted_num - truth_num) < 0.01:
                return True, 0.99
            # Partial credit for being close (within 10%).
            if truth_num != 0:
                error_pct = abs(submitted_num - truth_num) / abs(truth_num)
                if error_pct < 0.1:
                    return False, 0.05

    # Containment: the submission may embed the answer in a sentence.
    # BUG FIX: "" is a substring of every string, so an empty normalized
    # truth (e.g. ground_truth is None) previously marked ANY submission
    # correct; require a non-empty truth before the containment check.
    if norm_truth and norm_truth in norm_submitted:
        return True, 0.99

    return False, 0.01
104
+
105
+
106
def grade_sql_result(
    query_result: str, ground_truth: Any, is_error: bool
) -> Tuple[bool, float]:
    """
    Grade a SQL query result against ground truth.

    The result is expected as a markdown table ("| col1 | col2 |" rows);
    each cell of the first data row is compared to the ground truth.

    Args:
        query_result: The result string from executing the SQL query
        ground_truth: The expected correct answer
        is_error: Whether the query execution resulted in an error

    Returns:
        Tuple[bool, float]: (is_correct, score) - score strictly between 0.0 and 1.0
    """
    if is_error:
        return False, 0.01

    # Parse the markdown-table result into lines, dropping blanks and
    # the "|---" separator row.
    lines = query_result.strip().split("\n")
    data_lines = [l for l in lines if l.strip() and not l.startswith("|---")]

    # Need at least a header plus one data row.
    if len(data_lines) < 2:
        return False, 0.01

    # First data row follows the header (the len >= 2 guard above makes
    # the original "if len(data_lines) > 1 else ''" conditional dead).
    data_row = data_lines[1]

    # Extract non-empty cell values from the row.
    values = [v.strip() for v in data_row.split("|") if v.strip()]

    if not values:
        return False, 0.01

    # For multi-column results, accept a match in any cell.
    for value in values:
        is_correct, score = compare_values(value, ground_truth)
        if is_correct:
            return True, score

    return False, 0.01
    # BUG FIX: removed a large unreachable duplicate of the parsing
    # logic that followed this return; the dead copy also returned 0.0,
    # contradicting the "score strictly between 0 and 1" contract.
171
+
172
+
173
def grade_answer(
    submitted_answer: str, ground_truth: Any, db_engine: Any = None
) -> Tuple[bool, float]:
    """
    Grade the agent's submitted answer.

    If the submission looks like SQL and a database engine is available,
    the query is executed and its result graded; otherwise the raw answer
    text is compared to the ground truth directly.

    Args:
        submitted_answer: The agent's submitted answer string
        ground_truth: The expected correct answer
        db_engine: Optional database engine for SQL evaluation

    Returns:
        Tuple[bool, float]: (is_correct, score) - score strictly between 0.0 and 1.0
    """
    # Empty or whitespace-only submissions score the floor value.
    if not submitted_answer or not submitted_answer.strip():
        return False, 0.01

    submitted = submitted_answer.strip()

    # Heuristic SQL detection: presence of any common SQL keyword.
    upper_submitted = submitted.upper()
    looks_like_sql = any(
        keyword in upper_submitted
        for keyword in ("SELECT", "FROM", "WHERE", "JOIN", "GROUP", "ORDER")
    )

    if looks_like_sql and db_engine is not None:
        result, is_error = db_engine.execute_query(submitted)
        return grade_sql_result(result, ground_truth, is_error)

    # Plain-text answer: compare directly.
    return compare_values(submitted, ground_truth)
200
+
201
+
202
def calculate_final_score(
    is_correct: bool, total_steps: int, max_steps: int = 15
) -> float:
    """
    Calculate the final score for a task.

    A correct answer earns a base score of 0.7 plus an efficiency bonus
    of up to 0.3 for finishing in fewer steps.

    Args:
        is_correct: Whether the answer was correct
        total_steps: Number of steps taken
        max_steps: Maximum allowed steps (non-positive values award no
            efficiency bonus)

    Returns:
        float: Final score strictly between 0.0 and 1.0
    """
    if not is_correct:
        return 0.01

    base_score = 0.7

    # ROBUSTNESS FIX: a non-positive max_steps previously raised
    # ZeroDivisionError; award no bonus in that case instead.
    if max_steps > 0:
        efficiency_ratio = 1.0 - (total_steps / max_steps)
        efficiency_bonus = max(0.0, efficiency_ratio * 0.3)
    else:
        efficiency_bonus = 0.0

    # Ensure score is strictly between 0.0 and 1.0
    # Use 0.99 as max to stay strictly under 1.0
    return min(0.99, max(0.01, base_score + efficiency_bonus))