Spaces: Runtime error
Commit ·
9e89374
1
Parent(s): 3b1dc2a
Optimize env concurrency, dataframe ops, and inference loop
Browse files
- inference.py +12 -6
- server/app.py +3 -1
- server/data_wrangler_environment.py +44 -27
inference.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import asyncio
|
|
|
|
|
|
|
| 4 |
from openai import AsyncOpenAI
|
| 5 |
|
| 6 |
# OpenEnv V5 specific client components
|
|
@@ -19,6 +21,7 @@ BENCHMARK = "data_wrangler"
|
|
| 19 |
MAX_STEPS = 15
|
| 20 |
MAX_TOTAL_REWARD = 1.0
|
| 21 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
|
|
|
| 22 |
|
| 23 |
system_prompt = """\
|
| 24 |
SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT
|
|
@@ -63,7 +66,11 @@ Select Action: Which action type and parameters will execute this fix?
|
|
| 63 |
|
| 64 |
async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
|
| 65 |
obs_text = str(obs_dict)
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Priority 3: Error Reflection. Pass previous feedback directly to LLM if there was an error.
|
| 69 |
if "Error" in obs_dict.get("last_action_feedback", "") or "Exception" in obs_dict.get("last_action_feedback", ""):
|
|
@@ -77,13 +84,12 @@ async def get_model_message(client, step, obs_dict, last_reward, history, max_re
|
|
| 77 |
{"role": "system", "content": system_prompt},
|
| 78 |
{"role": "user", "content": prompt}
|
| 79 |
],
|
| 80 |
-
temperature=0.0
|
|
|
|
| 81 |
)
|
| 82 |
content = response.choices[0].message.content
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
match = re.search(r'(\{.*\})', content, re.DOTALL)
|
| 87 |
if match:
|
| 88 |
return json.loads(match.group(1))
|
| 89 |
else:
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
from openai import AsyncOpenAI
|
| 7 |
|
| 8 |
# OpenEnv V5 specific client components
|
|
|
|
| 21 |
MAX_STEPS = 15
|
| 22 |
MAX_TOTAL_REWARD = 1.0
|
| 23 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 24 |
+
MAX_HISTORY_ITEMS = int(os.environ.get("MAX_HISTORY_ITEMS", "6"))
|
| 25 |
|
| 26 |
system_prompt = """\
|
| 27 |
SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT
|
|
|
|
| 66 |
|
| 67 |
async def get_model_message(client, step, obs_dict, last_reward, history, max_retries=3):
|
| 68 |
obs_text = str(obs_dict)
|
| 69 |
+
trimmed_history = history[-MAX_HISTORY_ITEMS:] if history else []
|
| 70 |
+
prompt = (
|
| 71 |
+
f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\n"
|
| 72 |
+
f"History: {trimmed_history}\nChoose your next action (JSON matching schema)."
|
| 73 |
+
)
|
| 74 |
|
| 75 |
# Priority 3: Error Reflection. Pass previous feedback directly to LLM if there was an error.
|
| 76 |
if "Error" in obs_dict.get("last_action_feedback", "") or "Exception" in obs_dict.get("last_action_feedback", ""):
|
|
|
|
| 84 |
{"role": "system", "content": system_prompt},
|
| 85 |
{"role": "user", "content": prompt}
|
| 86 |
],
|
| 87 |
+
temperature=0.0,
|
| 88 |
+
max_tokens=220,
|
| 89 |
)
|
| 90 |
content = response.choices[0].message.content
|
| 91 |
+
|
| 92 |
+
match = re.search(r'(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})', content or "", re.DOTALL)
|
|
|
|
|
|
|
| 93 |
if match:
|
| 94 |
return json.loads(match.group(1))
|
| 95 |
else:
|
server/app.py
CHANGED
|
@@ -28,6 +28,8 @@ Usage:
|
|
| 28 |
python -m server.app
|
| 29 |
"""
|
| 30 |
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
from openenv.core.env_server.http_server import create_app
|
| 33 |
except Exception as e: # pragma: no cover
|
|
@@ -52,7 +54,7 @@ app = create_app(
|
|
| 52 |
DataWranglerAction,
|
| 53 |
DataWranglerObservation,
|
| 54 |
env_name="data_wrangler",
|
| 55 |
-
max_concurrent_envs=
|
| 56 |
)
|
| 57 |
|
| 58 |
|
|
|
|
| 28 |
python -m server.app
|
| 29 |
"""
|
| 30 |
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
try:
|
| 34 |
from openenv.core.env_server.http_server import create_app
|
| 35 |
except Exception as e: # pragma: no cover
|
|
|
|
| 54 |
DataWranglerAction,
|
| 55 |
DataWranglerObservation,
|
| 56 |
env_name="data_wrangler",
|
| 57 |
+
max_concurrent_envs=int(os.getenv("MAX_CONCURRENT_ENVS", "4")),
|
| 58 |
)
|
| 59 |
|
| 60 |
|
server/data_wrangler_environment.py
CHANGED
|
@@ -109,10 +109,11 @@ class DataWranglerEnvironment(Environment):
|
|
| 109 |
def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
|
| 110 |
stats = {}
|
| 111 |
for col in self.df.columns:
|
|
|
|
| 112 |
stats[col] = {
|
| 113 |
"dtype": str(self.df[col].dtype),
|
| 114 |
"missing_count": int(self.df[col].isna().sum()),
|
| 115 |
-
"sample_values":
|
| 116 |
}
|
| 117 |
|
| 118 |
return DataWranglerObservation(
|
|
@@ -132,6 +133,12 @@ class DataWranglerEnvironment(Environment):
|
|
| 132 |
self._initialize_task()
|
| 133 |
return self._get_obs()
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def step(self, action: DataWranglerAction) -> DataWranglerObservation: # type: ignore
|
| 136 |
self._state.step_count += 1
|
| 137 |
feedback = "Action executed successfully."
|
|
@@ -141,7 +148,8 @@ class DataWranglerEnvironment(Environment):
|
|
| 141 |
try:
|
| 142 |
if action.action_type == "drop_column":
|
| 143 |
col = action.target_column
|
| 144 |
-
|
|
|
|
| 145 |
self.df.drop(columns=[col], inplace=True)
|
| 146 |
if col not in self.target_df.columns:
|
| 147 |
reward = 0.2
|
|
@@ -149,72 +157,81 @@ class DataWranglerEnvironment(Environment):
|
|
| 149 |
reward = -0.5
|
| 150 |
feedback = f"Warning: dropped targeting column {col}"
|
| 151 |
else:
|
| 152 |
-
feedback =
|
| 153 |
|
| 154 |
elif action.action_type == "rename_column":
|
| 155 |
col = action.target_column
|
| 156 |
new_col = action.new_name
|
| 157 |
-
|
|
|
|
| 158 |
self.df.rename(columns={col: new_col}, inplace=True)
|
| 159 |
if new_col in self.target_df.columns:
|
| 160 |
reward = 0.2
|
| 161 |
else:
|
| 162 |
-
feedback =
|
| 163 |
|
| 164 |
elif action.action_type == "fill_missing":
|
| 165 |
col = action.target_column
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
reward = 0.1
|
| 169 |
else:
|
| 170 |
-
feedback =
|
| 171 |
|
| 172 |
elif action.action_type == "cast_type":
|
| 173 |
col = action.target_column
|
| 174 |
-
to_type = action.cast_to
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
reward = 0.2
|
| 185 |
else:
|
| 186 |
-
feedback =
|
| 187 |
|
| 188 |
elif action.action_type == "extract_regex":
|
| 189 |
col = action.target_column
|
| 190 |
new_col = action.new_name
|
| 191 |
pattern = action.regex_pattern
|
| 192 |
-
|
|
|
|
| 193 |
# extract the first capture group
|
| 194 |
extracted = self.df[col].astype(str).str.extract(pattern)[0]
|
| 195 |
self.df[new_col] = extracted
|
| 196 |
reward = 0.1
|
| 197 |
else:
|
| 198 |
-
feedback =
|
| 199 |
|
| 200 |
elif action.action_type == "datetime_parse":
|
| 201 |
col = action.target_column
|
| 202 |
fmt = action.format_string
|
| 203 |
-
|
| 204 |
-
|
|
|
|
| 205 |
reward = 0.1
|
| 206 |
else:
|
| 207 |
-
feedback =
|
| 208 |
|
| 209 |
elif action.action_type == "group_by_aggregate":
|
| 210 |
group_col = action.target_column
|
| 211 |
agg_col = action.agg_column
|
| 212 |
func = action.agg_func
|
| 213 |
-
|
| 214 |
-
|
|
|
|
| 215 |
reward = 0.2
|
| 216 |
else:
|
| 217 |
-
feedback =
|
| 218 |
|
| 219 |
elif action.action_type == "submit":
|
| 220 |
score = self._grade()
|
|
|
|
| 109 |
def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
|
| 110 |
stats = {}
|
| 111 |
for col in self.df.columns:
|
| 112 |
+
non_null = self.df[col].dropna()
|
| 113 |
stats[col] = {
|
| 114 |
"dtype": str(self.df[col].dtype),
|
| 115 |
"missing_count": int(self.df[col].isna().sum()),
|
| 116 |
+
"sample_values": non_null.astype(str).head(3).tolist(),
|
| 117 |
}
|
| 118 |
|
| 119 |
return DataWranglerObservation(
|
|
|
|
| 133 |
self._initialize_task()
|
| 134 |
return self._get_obs()
|
| 135 |
|
| 136 |
+
def _require_columns(self, *columns: str) -> str | None:
|
| 137 |
+
missing = [col for col in columns if not col or col not in self.df.columns]
|
| 138 |
+
if missing:
|
| 139 |
+
return f"Error: Column(s) not found: {', '.join(missing)}"
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
def step(self, action: DataWranglerAction) -> DataWranglerObservation: # type: ignore
|
| 143 |
self._state.step_count += 1
|
| 144 |
feedback = "Action executed successfully."
|
|
|
|
| 148 |
try:
|
| 149 |
if action.action_type == "drop_column":
|
| 150 |
col = action.target_column
|
| 151 |
+
err = self._require_columns(col)
|
| 152 |
+
if not err:
|
| 153 |
self.df.drop(columns=[col], inplace=True)
|
| 154 |
if col not in self.target_df.columns:
|
| 155 |
reward = 0.2
|
|
|
|
| 157 |
reward = -0.5
|
| 158 |
feedback = f"Warning: dropped targeting column {col}"
|
| 159 |
else:
|
| 160 |
+
feedback = err
|
| 161 |
|
| 162 |
elif action.action_type == "rename_column":
|
| 163 |
col = action.target_column
|
| 164 |
new_col = action.new_name
|
| 165 |
+
err = self._require_columns(col)
|
| 166 |
+
if not err:
|
| 167 |
self.df.rename(columns={col: new_col}, inplace=True)
|
| 168 |
if new_col in self.target_df.columns:
|
| 169 |
reward = 0.2
|
| 170 |
else:
|
| 171 |
+
feedback = err
|
| 172 |
|
| 173 |
elif action.action_type == "fill_missing":
|
| 174 |
col = action.target_column
|
| 175 |
+
err = self._require_columns(col)
|
| 176 |
+
if not err:
|
| 177 |
+
self.df[col] = self.df[col].fillna(action.fill_value)
|
| 178 |
reward = 0.1
|
| 179 |
else:
|
| 180 |
+
feedback = err
|
| 181 |
|
| 182 |
elif action.action_type == "cast_type":
|
| 183 |
col = action.target_column
|
| 184 |
+
to_type = (action.cast_to or "").lower()
|
| 185 |
+
err = self._require_columns(col)
|
| 186 |
+
if not err:
|
| 187 |
+
if to_type == "int":
|
| 188 |
+
self.df[col] = pd.to_numeric(self.df[col], errors="coerce").astype("Int64")
|
| 189 |
+
elif to_type == "float":
|
| 190 |
+
self.df[col] = pd.to_numeric(self.df[col], errors="coerce").astype(float)
|
| 191 |
+
elif to_type == "datetime":
|
| 192 |
+
self.df[col] = pd.to_datetime(self.df[col], errors="coerce")
|
| 193 |
+
elif to_type == "string":
|
| 194 |
+
self.df[col] = self.df[col].astype(str)
|
| 195 |
+
else:
|
| 196 |
+
feedback = f"Error: Unsupported cast type '{action.cast_to}'."
|
| 197 |
+
return self._get_obs(feedback=feedback, done=done, reward=reward)
|
| 198 |
reward = 0.2
|
| 199 |
else:
|
| 200 |
+
feedback = err
|
| 201 |
|
| 202 |
elif action.action_type == "extract_regex":
|
| 203 |
col = action.target_column
|
| 204 |
new_col = action.new_name
|
| 205 |
pattern = action.regex_pattern
|
| 206 |
+
err = self._require_columns(col)
|
| 207 |
+
if not err and new_col and pattern:
|
| 208 |
# extract the first capture group
|
| 209 |
extracted = self.df[col].astype(str).str.extract(pattern)[0]
|
| 210 |
self.df[new_col] = extracted
|
| 211 |
reward = 0.1
|
| 212 |
else:
|
| 213 |
+
feedback = err or "Error: 'new_name' and 'regex_pattern' are required."
|
| 214 |
|
| 215 |
elif action.action_type == "datetime_parse":
|
| 216 |
col = action.target_column
|
| 217 |
fmt = action.format_string
|
| 218 |
+
err = self._require_columns(col)
|
| 219 |
+
if not err:
|
| 220 |
+
self.df[col] = pd.to_datetime(self.df[col], format=fmt, errors="coerce")
|
| 221 |
reward = 0.1
|
| 222 |
else:
|
| 223 |
+
feedback = err
|
| 224 |
|
| 225 |
elif action.action_type == "group_by_aggregate":
|
| 226 |
group_col = action.target_column
|
| 227 |
agg_col = action.agg_column
|
| 228 |
func = action.agg_func
|
| 229 |
+
err = self._require_columns(group_col, agg_col)
|
| 230 |
+
if not err and func:
|
| 231 |
+
self.df = self.df.groupby(group_col, as_index=False, observed=True).agg({agg_col: func})
|
| 232 |
reward = 0.2
|
| 233 |
else:
|
| 234 |
+
feedback = err or "Error: 'agg_func' is required."
|
| 235 |
|
| 236 |
elif action.action_type == "submit":
|
| 237 |
score = self._grade()
|