KnightBlade committed on
Commit
29473f6
·
1 Parent(s): 8d27c3e

feat: Priority 2-4 implementations

Browse files
.github/workflows/ci.yml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Continuous integration / deployment pipeline for the Data Wrangler project.
# Runs the test suite on every push and PR against main, then mirrors main
# to the Hugging Face Space once tests pass.
name: Data Wrangler CI/CD

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r server/requirements.txt
          pip install pytest openenv

      - name: Run Tests
        run: |
          pytest tests/ -v

  deploy_hf_space:
    # Deploy only after the test job succeeds, and only for pushes to main
    # (PR runs stop at the test job).
    needs: test
    if: github.ref == 'refs/heads/main'
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
        with:
          # Full history is required so the force-push below can mirror
          # the complete branch to the Space remote.
          fetch-depth: 0

      - name: Push to Hugging Face
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git config --global user.email "github-actions[bot]@users.noreply.github.com"
          git config --global user.name "github-actions[bot]"
          git remote add hf https://user:$HF_TOKEN@huggingface.co/spaces/KnightBlade/data_wrangler
          git push -f hf main
inference.py CHANGED
@@ -58,29 +58,38 @@ Select Action: Which action type and parameters will execute this fix?
58
  }
59
  """
60
 
61
async def get_model_message(client, step, obs_dict, last_reward, history):
    """Ask the chat model for the next action, parsed as a JSON dict.

    Builds a single-turn prompt from the current observation, the last
    reward, and the action history, then extracts the first {...} span
    from the model reply.  Any API or parsing failure degrades to the
    fallback action {"action_type": "submit"}.
    """
    obs_text = str(obs_dict)
    prompt = f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\nHistory: {history}\nChoose your next action (JSON matching schema)."
    try:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            temperature=0.0
        )
        content = response.choices[0].message.content
        import json
        import re
        # Basic parsing of the JSON structure that follows the thinking tags
        match = re.search(r'(\{.*\})', content, re.DOTALL)
        if match:
            return json.loads(match.group(1))
        # Fallback if unparseable
        return {"action_type": "submit"}
    except Exception as e:
        # NOTE(review): swallows every error silently — any API failure or
        # malformed JSON becomes a bare "submit" with no retry or logging.
        return {"action_type": "submit"}
 
 
 
 
 
 
 
 
 
84
 
85
  def log_start(task, env, model):
86
  print(f"[START] task={task} env={env} model={model}")
 
58
  }
59
  """
60
 
61
async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
    """Query the chat model for the next action and parse it as a JSON dict.

    Builds a single-turn prompt from the current observation, last reward,
    and action history.  If the previous action's feedback reports an error,
    the prompt is augmented so the model can reflect on and correct the
    mistake (Priority 3).  The call/parse cycle is retried up to
    ``max_retries`` times, appending a warning to the prompt after each
    failed attempt.

    Args:
        client: OpenAI-compatible async client exposing chat.completions.create.
        step: Current step number, embedded in the prompt.
        obs_dict: Observation mapping; ``last_action_feedback`` is inspected
            for error text.
        last_reward: Reward from the previous step, embedded in the prompt.
        history: Action history, embedded in the prompt.
        max_retries: Maximum number of model calls before giving up.

    Returns:
        dict: Parsed action, or {"action_type": "submit"} if every attempt fails.
    """
    # Hoisted out of the retry loop — the original re-imported per attempt.
    import json
    import re

    obs_text = str(obs_dict)
    prompt = f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\nHistory: {history}\nChoose your next action (JSON matching schema)."

    # Priority 3: Error Reflection. Pass previous feedback directly to the
    # LLM if there was an error.  (Single lookup instead of two.)
    feedback = obs_dict.get("last_action_feedback", "")
    if "Error" in feedback or "Exception" in feedback:
        prompt += f"\nCRITICAL: Your last action failed with this error: {feedback}. Review your <thinking> block to correct your mistake before trying a new action."

    for attempt in range(max_retries):
        # Keep the try body limited to the operations that can actually raise;
        # the original wrapped the regex search and branching too.
        try:
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0
            )
            content = response.choices[0].message.content
        except Exception as e:
            prompt += f"\nWarning: Exception on attempt {attempt+1}: {str(e)}. Provide valid JSON."
            continue

        match = re.search(r'(\{.*\})', content, re.DOTALL)
        if match is None:
            prompt += f"\nWarning: Failed to extract JSON on attempt {attempt+1}. Provide ONLY valid JSON inside curly braces."
            continue

        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError as e:
            # Mirrors the original's broad-except path for malformed JSON.
            prompt += f"\nWarning: Exception on attempt {attempt+1}: {str(e)}. Provide valid JSON."

    # Fallback only if absolutely all retries fail.
    return {"action_type": "submit"}
93
 
94
def log_start(task, env, model):
    """Announce the beginning of a run on stdout."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner)
server/data_wrangler_environment.py CHANGED
@@ -157,21 +157,36 @@ class DataWranglerEnvironment(Environment):
157
 
158
def _grade(self) -> float:
    """Score the working frame against the target frame (pre-change version).

    Awards 0.5 for an exact, order-sensitive column-name match, plus up to
    0.5 scaled by the fraction of target columns whose values all compare
    equal element-wise.
    """
    score = 0.0
    if list(self.df.columns) == list(self.target_df.columns):
        score += 0.5
    # Match types and values
    value_matches = 0
    for col in self.df.columns:
        try:
            # simple match check
            match = (self.df[col] == self.target_df[col]).all()
            if match:
                value_matches += 1
        # NOTE(review): bare except — also swallows KeyboardInterrupt and
        # SystemExit, not just comparison failures.
        except:
            pass
    score += 0.5 * (value_matches / max(len(self.target_df.columns), 1))
    return score
175
 
176
  @property
177
  def state(self) -> State:
 
157
 
158
  def _grade(self) -> float:
159
  score = 0.0
160
+
161
+ # Priority 2: Partial credit per correct column (name + dtype + values)
162
+ correct_columns = 0
163
+ target_cols = set(self.target_df.columns)
164
+ current_cols = set(self.df.columns)
165
+
166
+ for col in target_cols:
167
+ if col in current_cols:
168
  try:
169
+ # Check dtype match
170
+ if self.df[col].dtype == self.target_df[col].dtype:
171
+ # Check value match
172
+ if (self.df[col].equals(self.target_df[col])):
173
+ correct_columns += 1
174
  except:
175
  pass
176
+
177
+ # Max score from matching columns is 0.8 (leaving 0.2 for efficiency)
178
+ column_score = (correct_columns / max(len(target_cols), 1)) * 0.8
179
+ score += column_score
180
+
181
+ # Priority 2: Step efficiency bonus
182
+ # If solved in few steps, give up to 0.2 bonus
183
+ ideal_steps = len(target_cols) # rough estimate
184
+ if self._state.step_count <= ideal_steps + 2:
185
+ score += 0.2
186
+ elif self._state.step_count <= ideal_steps + 5:
187
+ score += 0.1
188
 
189
+ return min(max(score, 0.0), 1.0)
190
 
191
  @property
192
  def state(self) -> State:
tests/test_env.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from server.data_wrangler_environment import DataWranglerEnvironment
3
+ from models import DataWranglerAction
4
+
5
def test_environment_reset():
    """A freshly reset environment exposes the raw starting frame."""
    environment = DataWranglerEnvironment()
    observation = environment.reset()
    expected_columns = ["User Name", "Unnamed: 0", "Age"]
    assert observation.columns == expected_columns
    assert observation.row_count == 3
    assert not observation.is_done
11
+
12
def test_drop_action_scoring():
    """Dropping a column the target keeps should surface warning feedback."""
    environment = DataWranglerEnvironment()
    environment.reset()
    # Removing "User Name" is a mistake the environment should penalize.
    action = DataWranglerAction(action_type="drop_column", target_column="User Name")
    observation = environment.step(action)
    assert "User Name" not in observation.columns
    feedback = observation.last_action_feedback
    assert "Warning" in feedback or "Error" in feedback
20
+
21
def test_successful_grading():
    """Solving task level 1 end-to-end earns a high reward."""
    import os
    os.environ["TASK_LEVEL"] = "1"
    environment = DataWranglerEnvironment()
    environment.reset()

    # Full fix: drop the index artifact, then snake_case both columns.
    actions = [
        DataWranglerAction(action_type="drop_column", target_column="Unnamed: 0"),
        DataWranglerAction(action_type="rename_column", target_column="User Name", new_name="user_name"),
        DataWranglerAction(action_type="rename_column", target_column="Age", new_name="age"),
    ]
    for action in actions:
        environment.step(action)

    # Submitting ends the episode; partial credit + efficiency bonus
    # should push the reward above 0.8.
    observation = environment.step(DataWranglerAction(action_type="submit"))
    assert observation.is_done
    assert observation.reward > 0.8