KnightBlade committed on
Commit
9e89374
·
1 Parent(s): 3b1dc2a

Optimize env concurrency, dataframe ops, and inference loop

Browse files
inference.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import sys
3
  import asyncio
 
 
4
  from openai import AsyncOpenAI
5
 
6
  # OpenEnv V5 specific client components
@@ -19,6 +21,7 @@ BENCHMARK = "data_wrangler"
19
  MAX_STEPS = 15
20
  MAX_TOTAL_REWARD = 1.0
21
  SUCCESS_SCORE_THRESHOLD = 0.5
 
22
 
23
  system_prompt = """\
24
  SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT
@@ -63,7 +66,11 @@ Select Action: Which action type and parameters will execute this fix?
63
 
64
  async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
65
  obs_text = str(obs_dict)
66
- prompt = f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\nHistory: {history}\nChoose your next action (JSON matching schema)."
 
 
 
 
67
 
68
  # Priority 3: Error Reflection. Pass previous feedback directly to LLM if there was an error.
69
  if "Error" in obs_dict.get("last_action_feedback", "") or "Exception" in obs_dict.get("last_action_feedback", ""):
@@ -77,13 +84,12 @@ async def get_model_message(client, step, obs_dict, last_reward, history, max_re
77
  {"role": "system", "content": system_prompt},
78
  {"role": "user", "content": prompt}
79
  ],
80
- temperature=0.0
 
81
  )
82
  content = response.choices[0].message.content
83
- import json
84
- import re
85
-
86
- match = re.search(r'(\{.*\})', content, re.DOTALL)
87
  if match:
88
  return json.loads(match.group(1))
89
  else:
 
1
  import os
2
  import sys
3
  import asyncio
4
+ import json
5
+ import re
6
  from openai import AsyncOpenAI
7
 
8
  # OpenEnv V5 specific client components
 
21
  MAX_STEPS = 15
22
  MAX_TOTAL_REWARD = 1.0
23
  SUCCESS_SCORE_THRESHOLD = 0.5
24
+ MAX_HISTORY_ITEMS = int(os.environ.get("MAX_HISTORY_ITEMS", "6"))
25
 
26
  system_prompt = """\
27
  SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT
 
66
 
67
  async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
68
  obs_text = str(obs_dict)
69
+ trimmed_history = history[-MAX_HISTORY_ITEMS:] if history else []
70
+ prompt = (
71
+ f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\n"
72
+ f"History: {trimmed_history}\nChoose your next action (JSON matching schema)."
73
+ )
74
 
75
  # Priority 3: Error Reflection. Pass previous feedback directly to LLM if there was an error.
76
  if "Error" in obs_dict.get("last_action_feedback", "") or "Exception" in obs_dict.get("last_action_feedback", ""):
 
84
  {"role": "system", "content": system_prompt},
85
  {"role": "user", "content": prompt}
86
  ],
87
+ temperature=0.0,
88
+ max_tokens=220,
89
  )
90
  content = response.choices[0].message.content
91
+
92
+ match = re.search(r'(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})', content or "", re.DOTALL)
 
 
93
  if match:
94
  return json.loads(match.group(1))
95
  else:
server/app.py CHANGED
@@ -28,6 +28,8 @@ Usage:
28
  python -m server.app
29
  """
30
 
 
 
31
  try:
32
  from openenv.core.env_server.http_server import create_app
33
  except Exception as e: # pragma: no cover
@@ -52,7 +54,7 @@ app = create_app(
52
  DataWranglerAction,
53
  DataWranglerObservation,
54
  env_name="data_wrangler",
55
- max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
56
  )
57
 
58
 
 
28
  python -m server.app
29
  """
30
 
31
+ import os
32
+
33
  try:
34
  from openenv.core.env_server.http_server import create_app
35
  except Exception as e: # pragma: no cover
 
54
  DataWranglerAction,
55
  DataWranglerObservation,
56
  env_name="data_wrangler",
57
+ max_concurrent_envs=int(os.getenv("MAX_CONCURRENT_ENVS", "4")),
58
  )
59
 
60
 
server/data_wrangler_environment.py CHANGED
@@ -109,10 +109,11 @@ class DataWranglerEnvironment(Environment):
109
  def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
110
  stats = {}
111
  for col in self.df.columns:
 
112
  stats[col] = {
113
  "dtype": str(self.df[col].dtype),
114
  "missing_count": int(self.df[col].isna().sum()),
115
- "sample_values": self.df[col].dropna().astype(str).tolist()[:3]
116
  }
117
 
118
  return DataWranglerObservation(
@@ -132,6 +133,12 @@ class DataWranglerEnvironment(Environment):
132
  self._initialize_task()
133
  return self._get_obs()
134
 
 
 
 
 
 
 
135
  def step(self, action: DataWranglerAction) -> DataWranglerObservation: # type: ignore
136
  self._state.step_count += 1
137
  feedback = "Action executed successfully."
@@ -141,7 +148,8 @@ class DataWranglerEnvironment(Environment):
141
  try:
142
  if action.action_type == "drop_column":
143
  col = action.target_column
144
- if col in self.df.columns:
 
145
  self.df.drop(columns=[col], inplace=True)
146
  if col not in self.target_df.columns:
147
  reward = 0.2
@@ -149,72 +157,81 @@ class DataWranglerEnvironment(Environment):
149
  reward = -0.5
150
  feedback = f"Warning: dropped targeting column {col}"
151
  else:
152
- feedback = f"Error: Column '{col}' not found."
153
 
154
  elif action.action_type == "rename_column":
155
  col = action.target_column
156
  new_col = action.new_name
157
- if col in self.df.columns:
 
158
  self.df.rename(columns={col: new_col}, inplace=True)
159
  if new_col in self.target_df.columns:
160
  reward = 0.2
161
  else:
162
- feedback = f"Error: Column '{col}' not found."
163
 
164
  elif action.action_type == "fill_missing":
165
  col = action.target_column
166
- if col in self.df.columns:
167
- self.df[col].fillna(action.fill_value, inplace=True)
 
168
  reward = 0.1
169
  else:
170
- feedback = f"Error: Column '{col}' not found."
171
 
172
  elif action.action_type == "cast_type":
173
  col = action.target_column
174
- to_type = action.cast_to
175
- if col in self.df.columns:
176
- if to_type == 'int':
177
- self.df = self.df.astype({col: int})
178
- elif to_type == 'float':
179
- self.df = self.df.astype({col: float})
180
- elif to_type == 'datetime':
181
- self.df[col] = pd.to_datetime(self.df[col])
182
- elif to_type == 'string':
183
- self.df = self.df.astype({col: str})
 
 
 
 
184
  reward = 0.2
185
  else:
186
- feedback = f"Error: Column '{col}' not found."
187
 
188
  elif action.action_type == "extract_regex":
189
  col = action.target_column
190
  new_col = action.new_name
191
  pattern = action.regex_pattern
192
- if col in self.df.columns:
 
193
  # extract the first capture group
194
  extracted = self.df[col].astype(str).str.extract(pattern)[0]
195
  self.df[new_col] = extracted
196
  reward = 0.1
197
  else:
198
- feedback = f"Error: Column '{col}' not found."
199
 
200
  elif action.action_type == "datetime_parse":
201
  col = action.target_column
202
  fmt = action.format_string
203
- if col in self.df.columns:
204
- self.df[col] = pd.to_datetime(self.df[col], format=fmt)
 
205
  reward = 0.1
206
  else:
207
- feedback = f"Error: Column '{col}' not found."
208
 
209
  elif action.action_type == "group_by_aggregate":
210
  group_col = action.target_column
211
  agg_col = action.agg_column
212
  func = action.agg_func
213
- if group_col in self.df.columns and agg_col in self.df.columns:
214
- self.df = self.df.groupby(group_col, as_index=False).agg({agg_col: func})
 
215
  reward = 0.2
216
  else:
217
- feedback = f"Error: Columns '{group_col}' or '{agg_col}' not found."
218
 
219
  elif action.action_type == "submit":
220
  score = self._grade()
 
109
  def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
110
  stats = {}
111
  for col in self.df.columns:
112
+ non_null = self.df[col].dropna()
113
  stats[col] = {
114
  "dtype": str(self.df[col].dtype),
115
  "missing_count": int(self.df[col].isna().sum()),
116
+ "sample_values": non_null.astype(str).head(3).tolist(),
117
  }
118
 
119
  return DataWranglerObservation(
 
133
  self._initialize_task()
134
  return self._get_obs()
135
 
136
+ def _require_columns(self, *columns: str) -> str | None:
137
+ missing = [col for col in columns if not col or col not in self.df.columns]
138
+ if missing:
139
+ return f"Error: Column(s) not found: {', '.join(missing)}"
140
+ return None
141
+
142
  def step(self, action: DataWranglerAction) -> DataWranglerObservation: # type: ignore
143
  self._state.step_count += 1
144
  feedback = "Action executed successfully."
 
148
  try:
149
  if action.action_type == "drop_column":
150
  col = action.target_column
151
+ err = self._require_columns(col)
152
+ if not err:
153
  self.df.drop(columns=[col], inplace=True)
154
  if col not in self.target_df.columns:
155
  reward = 0.2
 
157
  reward = -0.5
158
  feedback = f"Warning: dropped targeting column {col}"
159
  else:
160
+ feedback = err
161
 
162
  elif action.action_type == "rename_column":
163
  col = action.target_column
164
  new_col = action.new_name
165
+ err = self._require_columns(col)
166
+ if not err:
167
  self.df.rename(columns={col: new_col}, inplace=True)
168
  if new_col in self.target_df.columns:
169
  reward = 0.2
170
  else:
171
+ feedback = err
172
 
173
  elif action.action_type == "fill_missing":
174
  col = action.target_column
175
+ err = self._require_columns(col)
176
+ if not err:
177
+ self.df[col] = self.df[col].fillna(action.fill_value)
178
  reward = 0.1
179
  else:
180
+ feedback = err
181
 
182
  elif action.action_type == "cast_type":
183
  col = action.target_column
184
+ to_type = (action.cast_to or "").lower()
185
+ err = self._require_columns(col)
186
+ if not err:
187
+ if to_type == "int":
188
+ self.df[col] = pd.to_numeric(self.df[col], errors="coerce").astype("Int64")
189
+ elif to_type == "float":
190
+ self.df[col] = pd.to_numeric(self.df[col], errors="coerce").astype(float)
191
+ elif to_type == "datetime":
192
+ self.df[col] = pd.to_datetime(self.df[col], errors="coerce")
193
+ elif to_type == "string":
194
+ self.df[col] = self.df[col].astype(str)
195
+ else:
196
+ feedback = f"Error: Unsupported cast type '{action.cast_to}'."
197
+ return self._get_obs(feedback=feedback, done=done, reward=reward)
198
  reward = 0.2
199
  else:
200
+ feedback = err
201
 
202
  elif action.action_type == "extract_regex":
203
  col = action.target_column
204
  new_col = action.new_name
205
  pattern = action.regex_pattern
206
+ err = self._require_columns(col)
207
+ if not err and new_col and pattern:
208
  # extract the first capture group
209
  extracted = self.df[col].astype(str).str.extract(pattern)[0]
210
  self.df[new_col] = extracted
211
  reward = 0.1
212
  else:
213
+ feedback = err or "Error: 'new_name' and 'regex_pattern' are required."
214
 
215
  elif action.action_type == "datetime_parse":
216
  col = action.target_column
217
  fmt = action.format_string
218
+ err = self._require_columns(col)
219
+ if not err:
220
+ self.df[col] = pd.to_datetime(self.df[col], format=fmt, errors="coerce")
221
  reward = 0.1
222
  else:
223
+ feedback = err
224
 
225
  elif action.action_type == "group_by_aggregate":
226
  group_col = action.target_column
227
  agg_col = action.agg_column
228
  func = action.agg_func
229
+ err = self._require_columns(group_col, agg_col)
230
+ if not err and func:
231
+ self.df = self.df.groupby(group_col, as_index=False, observed=True).agg({agg_col: func})
232
  reward = 0.2
233
  else:
234
+ feedback = err or "Error: 'agg_func' is required."
235
 
236
  elif action.action_type == "submit":
237
  score = self._grade()