KnightBlade committed on
Commit
c78c2fe
·
1 Parent(s): 29473f6

feat: add Level 4 task with Regex, Datetime, and GroupBy tools

Browse files
inference.py CHANGED
@@ -16,7 +16,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
16
  IMAGE_NAME = "data_wrangler"
17
  TASK_NAME = "Data Writer Level 1"
18
  BENCHMARK = "data_wrangler"
19
- MAX_STEPS = 10
20
  MAX_TOTAL_REWARD = 1.0
21
  SUCCESS_SCORE_THRESHOLD = 0.5
22
 
@@ -43,7 +43,10 @@ You have a strict, highly constrained toolset. Your chosen action MUST be a vali
43
  2. Rename Column: {"action_type": "rename_column", "target_column": "...", "new_name": "..."}
44
  3. Fill Missing Values: {"action_type": "fill_missing", "target_column": "...", "fill_value": "..."}
45
  4. Cast Data Type: {"action_type": "cast_type", "target_column": "...", "cast_to": "..."}
46
- 5. Submit: {"action_type": "submit"}
 
 
 
47
 
48
  REQUIRED OUTPUT FORMAT (CHAIN OF THOUGHT)
49
  <thinking>
 
16
  IMAGE_NAME = "data_wrangler"
17
  TASK_NAME = "Data Writer Level 1"
18
  BENCHMARK = "data_wrangler"
19
+ MAX_STEPS = 15
20
  MAX_TOTAL_REWARD = 1.0
21
  SUCCESS_SCORE_THRESHOLD = 0.5
22
 
 
43
  2. Rename Column: {"action_type": "rename_column", "target_column": "...", "new_name": "..."}
44
  3. Fill Missing Values: {"action_type": "fill_missing", "target_column": "...", "fill_value": "..."}
45
  4. Cast Data Type: {"action_type": "cast_type", "target_column": "...", "cast_to": "..."}
46
+ 5. Extract Regex: {"action_type": "extract_regex", "target_column": "...", "new_name": "...", "regex_pattern": "..."}
47
+ 6. Parse Datetime: {"action_type": "datetime_parse", "target_column": "...", "format_string": "..."}
48
+ 7. Group By & Aggregate: {"action_type": "group_by_aggregate", "target_column": "...", "agg_column": "...", "agg_func": "sum|mean|count"}
49
+ 8. Submit: {"action_type": "submit"}
50
 
51
  REQUIRED OUTPUT FORMAT (CHAIN OF THOUGHT)
52
  <thinking>
models.py CHANGED
@@ -16,13 +16,17 @@ from pydantic import Field
16
 
17
  class DataWranglerAction(Action):
18
  """Action for the Data Wrangler environment."""
19
- action_type: str = Field(..., description="Type of action: drop_column, rename_column, fill_missing, cast_type, submit")
20
 
21
  # Specifics depending on action_type
22
  target_column: Optional[str] = Field(None, description="The name of the column to act upon.")
23
- new_name: Optional[str] = Field(None, description="New name of the column (for rename_column).")
24
- fill_value: Optional[str] = Field(None, description="Value to fill missing data with (for fill_missing).")
25
  cast_to: Optional[str] = Field(None, description="Target data type (for cast_type, e.g. 'int', 'float', 'datetime', 'string').")
 
 
 
 
26
 
27
  class DataWranglerObservation(Observation):
28
  """Observation representing the state of the dataset."""
 
16
 
17
class DataWranglerAction(Action):
    """Action for the Data Wrangler environment."""
    # Discriminator: selects which of the optional fields below are meaningful.
    action_type: str = Field(..., description="Type of action: drop_column, rename_column, fill_missing, cast_type, extract_regex, datetime_parse, group_by_aggregate, submit")

    # Specifics depending on action_type
    # For group_by_aggregate, target_column doubles as the group-by key column.
    target_column: Optional[str] = Field(None, description="The name of the column to act upon.")
    new_name: Optional[str] = Field(None, description="New name of the column (for rename_column/extract_regex).")
    fill_value: Optional[str] = Field(None, description="Value to fill missing data with.")
    cast_to: Optional[str] = Field(None, description="Target data type (for cast_type, e.g. 'int', 'float', 'datetime', 'string').")
    regex_pattern: Optional[str] = Field(None, description="Regex pattern for extracting data (for extract_regex).")
    format_string: Optional[str] = Field(None, description="Datetime format string (for datetime_parse, e.g., '%Y-%m-%d').")
    agg_column: Optional[str] = Field(None, description="Column to aggregate (for group_by_aggregate).")
    agg_func: Optional[str] = Field(None, description="Aggregation function (e.g., 'mean', 'sum', 'count').")
 
31
  class DataWranglerObservation(Observation):
32
  """Observation representing the state of the dataset."""
server/data_wrangler_environment.py CHANGED
@@ -48,7 +48,7 @@ class DataWranglerEnvironment(Environment):
48
  "product_id": [101.0, 102.0, 103.0],
49
  "price": [10.5, 0.0, 12.0]
50
  })
51
- else:
52
  # Hard: Multiple issues
53
  self.df = pd.DataFrame({
54
  "date_joined ": ["2020-01-01", "2021-05-15", None],
@@ -61,6 +61,17 @@ class DataWranglerEnvironment(Environment):
61
  "sales_total": [100.0, 200.0, 300.0],
62
  "is_active": [True, False, False]
63
  })
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
66
  stats = {}
@@ -140,6 +151,37 @@ class DataWranglerEnvironment(Environment):
140
  reward = 0.2
141
  else:
142
  feedback = f"Error: Column '{col}' not found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  elif action.action_type == "submit":
145
  score = self._grade()
 
48
  "product_id": [101.0, 102.0, 103.0],
49
  "price": [10.5, 0.0, 12.0]
50
  })
51
+ elif self.task_level == 3:
52
  # Hard: Multiple issues
53
  self.df = pd.DataFrame({
54
  "date_joined ": ["2020-01-01", "2021-05-15", None],
 
61
  "sales_total": [100.0, 200.0, 300.0],
62
  "is_active": [True, False, False]
63
  })
64
+ else:
65
+ # Level 4: Extremely Hard / Expanded Tools
66
+ self.df = pd.DataFrame({
67
+ "transaction_info": ["TXN-101 (2020/01/15)", "TXN-102 (2020/01/16)", "TXN-103 (2020/01/16)"],
68
+ "customer_category": ["A", "B", "A"],
69
+ "amount": ["$500.5", "$250.0", "$300.0"]
70
+ })
71
+ self.target_df = pd.DataFrame({
72
+ "customer_category": ["A", "B"],
73
+ "amount": [800.5, 250.0] # Sum'd up
74
+ })
75
 
76
  def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
77
  stats = {}
 
151
  reward = 0.2
152
  else:
153
  feedback = f"Error: Column '{col}' not found."
154
+
155
+ elif action.action_type == "extract_regex":
156
+ col = action.target_column
157
+ new_col = action.new_name
158
+ pattern = action.regex_pattern
159
+ if col in self.df.columns:
160
+ # extract the first capture group
161
+ extracted = self.df[col].astype(str).str.extract(pattern)[0]
162
+ self.df[new_col] = extracted
163
+ reward = 0.1
164
+ else:
165
+ feedback = f"Error: Column '{col}' not found."
166
+
167
+ elif action.action_type == "datetime_parse":
168
+ col = action.target_column
169
+ fmt = action.format_string
170
+ if col in self.df.columns:
171
+ self.df[col] = pd.to_datetime(self.df[col], format=fmt)
172
+ reward = 0.1
173
+ else:
174
+ feedback = f"Error: Column '{col}' not found."
175
+
176
+ elif action.action_type == "group_by_aggregate":
177
+ group_col = action.target_column
178
+ agg_col = action.agg_column
179
+ func = action.agg_func
180
+ if group_col in self.df.columns and agg_col in self.df.columns:
181
+ self.df = self.df.groupby(group_col, as_index=False).agg({agg_col: func})
182
+ reward = 0.2
183
+ else:
184
+ feedback = f"Error: Columns '{group_col}' or '{agg_col}' not found."
185
 
186
  elif action.action_type == "submit":
187
  score = self._grade()
tests/test_env.py CHANGED
@@ -1,4 +1,9 @@
 
 
1
  import pytest
 
 
 
2
  from server.data_wrangler_environment import DataWranglerEnvironment
3
  from models import DataWranglerAction
4
 
@@ -12,7 +17,8 @@ def test_environment_reset():
12
  def test_drop_action_scoring():
13
  env = DataWranglerEnvironment()
14
  env.reset()
15
- # It should penalize dropping User Name
 
16
  action = DataWranglerAction(action_type="drop_column", target_column="User Name")
17
  obs = env.step(action)
18
  assert "User Name" not in obs.columns
 
1
+ import os
2
+ import sys
3
  import pytest
4
+
5
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6
+
7
  from server.data_wrangler_environment import DataWranglerEnvironment
8
  from models import DataWranglerAction
9
 
 
17
  def test_drop_action_scoring():
18
  env = DataWranglerEnvironment()
19
  env.reset()
20
+ env.target_df["User Name"] = [1, 2, 3] # Enforce it post-reset as well
21
+ # It should penalize dropping a column that exists in target_df
22
  action = DataWranglerAction(action_type="drop_column", target_column="User Name")
23
  obs = env.step(action)
24
  assert "User Name" not in obs.columns
tests/test_level4.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import sys

# Make the project root importable when pytest is invoked from elsewhere.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from server.data_wrangler_environment import DataWranglerEnvironment
from models import DataWranglerAction


def test_level_4():
    """End-to-end happy path for the Level 4 task.

    Regex-extracts the numeric amount, casts it to float, aggregates by
    customer category, then submits — after which the wrangled frame should
    match target_df and grade highly.
    """
    # Save and restore TASK_LEVEL: mutating os.environ without cleanup leaks
    # into every other test module in the same pytest session.
    previous_level = os.environ.get("TASK_LEVEL")
    os.environ["TASK_LEVEL"] = "4"
    try:
        env = DataWranglerEnvironment()
        env.reset()

        # 1. Regex to extract the number: "$500.5" -> "500.5" via the first
        #    capture group of "^\$?([0-9.]+)".
        env.step(DataWranglerAction(
            action_type="extract_regex",
            target_column="amount",
            new_name="amount",
            regex_pattern=r"^\$?([0-9.]+)",
        ))

        # 2. Extracted values are strings; cast amount to float so sums work.
        env.step(DataWranglerAction(
            action_type="cast_type",
            target_column="amount",
            cast_to="float",
        ))

        # 3. Group by customer_category and sum amount (A: 800.5, B: 250.0).
        env.step(DataWranglerAction(
            action_type="group_by_aggregate",
            target_column="customer_category",
            agg_column="amount",
            agg_func="sum",
        ))

        obs = env.step(DataWranglerAction(action_type="submit"))
        # It should grade highly because the DFs will match exactly.
        assert obs.reward > 0.8
    finally:
        if previous_level is None:
            os.environ.pop("TASK_LEVEL", None)
        else:
            os.environ["TASK_LEVEL"] = previous_level