rahul2124 commited on
Commit
ac49ad8
·
verified ·
1 Parent(s): 72e26c9

Upload folder using huggingface_hub

Browse files
.pytest_cache/v/cache/lastfailed ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
inference.py CHANGED
@@ -182,6 +182,11 @@ def run_task(client: OpenAI, env: SQLArenaEnvironment, task_config: dict) -> flo
182
  break
183
 
184
  final_score = min(max(best_score, 0.0), 1.0)
 
 
 
 
 
185
  success = final_score >= 0.5
186
 
187
  except Exception as e:
 
182
  break
183
 
184
  final_score = min(max(best_score, 0.0), 1.0)
185
+ # Clamp to strictly between 0 and 1
186
+ if final_score <= 0.0:
187
+ final_score = 0.01
188
+ if final_score >= 1.0:
189
+ final_score = 0.99
190
  success = final_score >= 0.5
191
 
192
  except Exception as e:
src/sql_arena/environment.py CHANGED
@@ -136,6 +136,12 @@ class SQLArenaEnvironment:
136
  reward = score * 0.5 + improvement * 0.5
137
 
138
  reward = round(min(max(reward, 0.0), 1.0), 4)
 
 
 
 
 
 
139
  state.rewards_history.append(reward)
140
  state.total_reward += reward
141
 
 
136
  reward = score * 0.5 + improvement * 0.5
137
 
138
  reward = round(min(max(reward, 0.0), 1.0), 4)
139
+ # Clamp to strictly between 0 and 1
140
+ if reward <= 0.0:
141
+ reward = 0.01
142
+ if reward >= 1.0:
143
+ reward = 0.99
144
+
145
  state.rewards_history.append(reward)
146
  state.total_reward += reward
147
 
src/sql_arena/graders.py CHANGED
@@ -194,6 +194,10 @@ def grade_result(
194
 
195
  # ---- Final score ----
196
  score = round(min(max(score, 0.0), 1.0), 4)
 
 
 
 
197
  feedback_parts.append(f"\nTotal Score: {score:.2f}/1.00")
198
 
199
  return score, "\n".join(feedback_parts)
 
194
 
195
  # ---- Final score ----
196
  score = round(min(max(score, 0.0), 1.0), 4)
197
+ if score <= 0.0:
198
+ score = 0.01
199
+ if score >= 1.0:
200
+ score = 0.99
201
  feedback_parts.append(f"\nTotal Score: {score:.2f}/1.00")
202
 
203
  return score, "\n".join(feedback_parts)
tests/test_env.py CHANGED
@@ -32,7 +32,7 @@ class TestEnvironmentBasics:
32
  self.env.reset(difficulty="basic_select", task_id="easy_001")
33
  action = SQLArenaAction(sql_query="INVALID SQL QUERY")
34
  result = self.env.step(action)
35
- assert result.reward == 0.0
36
  assert result.observation.error_message is not None
37
 
38
  def test_state_tracking(self):
 
32
  self.env.reset(difficulty="basic_select", task_id="easy_001")
33
  action = SQLArenaAction(sql_query="INVALID SQL QUERY")
34
  result = self.env.step(action)
35
+ assert result.reward == 0.01 # Clamped to strictly > 0
36
  assert result.observation.error_message is not None
37
 
38
  def test_state_tracking(self):