KnightBlade committed on
Commit
b15226e
·
1 Parent(s): bf11096

Align inference script with validator env vars and strict stdout format

Browse files
Files changed (2) hide show
  1. client.py +6 -5
  2. inference.py +56 -34
client.py CHANGED
@@ -54,9 +54,7 @@ class DataWranglerEnv(
54
  Returns:
55
  Dictionary representation suitable for JSON encoding
56
  """
57
- return {
58
- "message": action.message,
59
- }
60
 
61
  def _parse_result(self, payload: Dict) -> StepResult[DataWranglerObservation]:
62
  """
@@ -70,8 +68,11 @@ class DataWranglerEnv(
70
  """
71
  obs_data = payload.get("observation", {})
72
  observation = DataWranglerObservation(
73
- echoed_message=obs_data.get("echoed_message", ""),
74
- message_length=obs_data.get("message_length", 0),
 
 
 
75
  done=payload.get("done", False),
76
  reward=payload.get("reward"),
77
  metadata=obs_data.get("metadata", {}),
 
54
  Returns:
55
  Dictionary representation suitable for JSON encoding
56
  """
57
+ return action.model_dump(mode="json", exclude_none=True)
 
 
58
 
59
  def _parse_result(self, payload: Dict) -> StepResult[DataWranglerObservation]:
60
  """
 
68
  """
69
  obs_data = payload.get("observation", {})
70
  observation = DataWranglerObservation(
71
+ columns=obs_data.get("columns", []),
72
+ row_count=obs_data.get("row_count", 0),
73
+ column_stats=obs_data.get("column_stats", {}),
74
+ last_action_feedback=obs_data.get("last_action_feedback", ""),
75
+ is_done=obs_data.get("is_done", payload.get("done", False)),
76
  done=payload.get("done", False),
77
  reward=payload.get("reward"),
78
  metadata=obs_data.get("metadata", {}),
inference.py CHANGED
@@ -1,22 +1,21 @@
1
  import os
2
- import sys
3
  import asyncio
4
  import json
5
  import re
6
  from openai import AsyncOpenAI
7
 
8
- # OpenEnv V5 specific client components
9
- # We import directly since OpenEnv varies slightly in versions, but this mirrors the validator script expectations.
10
  try:
11
- from openenv.core.client import EnvClient
12
- except ImportError:
13
- pass
14
-
15
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
16
- API_KEY = os.environ.get("OPENAI_API_KEY", "")
17
- MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
18
- IMAGE_NAME = "data_wrangler"
19
- TASK_NAME = "Data Writer Level 1"
 
 
20
  BENCHMARK = "data_wrangler"
21
  MAX_STEPS = 15
22
  MAX_TOTAL_REWARD = 1.0
@@ -100,30 +99,55 @@ async def get_model_message(client, step, obs_dict, last_reward, history, max_re
100
  # Fallback only if absolutely all retries fail
101
  return {"action_type": "submit"}
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def log_start(task, env, model):
104
  print(f"[START] task={task} env={env} model={model}")
105
 
 
106
  def log_step(step, action, reward, done, error):
107
- print(f"[STEP] step={step} action={action} reward={reward} done={done} error={error}")
 
 
 
 
108
 
109
- def log_end(success, steps, score, rewards):
110
- print(f"[END] success={success} steps={steps} score={score} rewards={rewards}")
 
 
111
 
112
  async def main():
113
- if not API_KEY:
114
- print("Missing OPENAI_API_KEY environment variable.")
 
 
115
  return
116
 
117
- client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
118
-
119
- print(f"[DEBUG] Spinning up {IMAGE_NAME} environment container...", flush=True)
120
  try:
121
  from client import DataWranglerEnv
122
- env = DataWranglerEnv.from_docker_image(IMAGE_NAME)
123
- except Exception as e:
124
- print(f"[DEBUG] Docker env start failed ({e}). Falling back to local direct Python import.", flush=True)
125
  from server.data_wrangler_environment import DataWranglerEnvironment
126
- env = DataWranglerEnvironment() # Fallback for local debugging
127
 
128
  history = []
129
  rewards = []
@@ -131,8 +155,6 @@ async def main():
131
  score = 0.0
132
  success = False
133
 
134
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
135
-
136
  try:
137
  if hasattr(env, 'reset') and not asyncio.iscoroutinefunction(env.reset):
138
  result = env.reset()
@@ -155,15 +177,14 @@ async def main():
155
  break
156
 
157
  action_data = await get_model_message(client, step, obs_dict, last_reward, history)
158
-
159
- from models import DataWranglerAction
160
  action_obj = DataWranglerAction(**action_data)
161
-
162
  if hasattr(env, 'step') and not asyncio.iscoroutinefunction(env.step):
163
  result = env.step(action_obj)
164
  else:
165
  result = await env.step(action_obj)
166
-
167
  obs = getattr(result, "observation", result)
168
  obs_dict = {
169
  "columns": getattr(obs, "columns", []),
@@ -175,7 +196,8 @@ async def main():
175
 
176
  reward = getattr(result, "reward", getattr(obs, "reward", 0.0)) or 0.0
177
  done = getattr(result, "done", getattr(obs, "is_done", False))
178
- error = None
 
179
 
180
  rewards.append(reward)
181
  steps_taken = step
@@ -200,9 +222,9 @@ async def main():
200
  else:
201
  env.close()
202
  except Exception as e:
203
- print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
204
-
205
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
206
 
207
  if __name__ == "__main__":
208
  asyncio.run(main())
 
1
  import os
 
2
  import asyncio
3
  import json
4
  import re
5
  from openai import AsyncOpenAI
6
 
 
 
7
  try:
8
+ from models import DataWranglerAction
9
+ except (ImportError, ModuleNotFoundError):
10
+ import sys
11
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
12
+ from models import DataWranglerAction
13
+
14
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
15
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "data_wrangler")
18
+ TASK_NAME = "data_wrangler_task"
19
  BENCHMARK = "data_wrangler"
20
  MAX_STEPS = 15
21
  MAX_TOTAL_REWARD = 1.0
 
99
  # Fallback only if absolutely all retries fail
100
  return {"action_type": "submit"}
101
 
102
def _one_line(text):
    """Return ``str(text)`` with every CR/LF replaced by a single space.

    Each log record must occupy exactly one stdout line, so both newline
    and carriage-return characters are neutralised.
    """
    return (
        str(text)
        .replace("\r\n", " ")
        .replace("\n", " ")
        .replace("\r", " ")
    )


def _bool_str(value):
    """Render the truthiness of *value* as lowercase ``true`` / ``false``."""
    return "true" if value else "false"


def _action_str(action):
    """Serialise *action* as compact single-line JSON.

    Falls back to the one-line ``str()`` form when *action* is not
    JSON-serialisable (for example a model object instead of a dict).
    """
    try:
        return json.dumps(action, separators=(",", ":"), ensure_ascii=False)
    except (TypeError, ValueError, RecursionError):
        # json.dumps raises TypeError for unsupported types, ValueError for
        # circular references and RecursionError for extreme nesting.
        return _one_line(action)


def _reward_str(value):
    """Format *value* as a reward with exactly two decimals.

    Anything that cannot be converted to a finite float (``None``, text,
    ``nan``, ``inf``, oversized ints) is reported as ``0.00`` so the field
    always stays numeric.
    """
    import math  # local import: the module header does not import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return "0.00"
    if not math.isfinite(number):
        return "0.00"
    return f"{number:.2f}"


def log_start(task, env, model):
    """Print the ``[START]`` record for a run and flush stdout."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step, action, reward, done, error):
    """Print one ``[STEP]`` record and flush stdout.

    Args:
        step: Step counter written verbatim.
        action: Action payload; emitted as compact JSON when possible.
        reward: Step reward; formatted with two decimals.
        done: Episode-finished flag; emitted as ``true`` / ``false``.
        error: Error text, or ``None`` which is emitted as ``null``.
    """
    error_str = "null" if error is None else _one_line(error)
    print(
        f"[STEP] step={step} action={_action_str(action)} "
        f"reward={_reward_str(reward)} done={_bool_str(done)} error={error_str}",
        flush=True,
    )


def log_end(success, steps, rewards):
    """Print the ``[END]`` record and flush stdout.

    Args:
        success: Overall outcome; emitted as ``true`` / ``false``.
        steps: Number of steps taken, written verbatim.
        rewards: Iterable of per-step rewards, emitted comma-separated;
            ``None`` is treated as no rewards.
    """
    rewards_csv = ",".join(_reward_str(r) for r in (rewards or ()))
    print(
        f"[END] success={_bool_str(success)} steps={steps} rewards={rewards_csv}",
        flush=True,
    )
135
 
136
  async def main():
137
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
138
+
139
+ if not HF_TOKEN:
140
+ log_end(success=False, steps=0, rewards=[])
141
  return
142
 
143
+ client = AsyncOpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
144
+
 
145
  try:
146
  from client import DataWranglerEnv
147
+ env = DataWranglerEnv.from_docker_image(LOCAL_IMAGE_NAME)
148
+ except Exception:
 
149
  from server.data_wrangler_environment import DataWranglerEnvironment
150
+ env = DataWranglerEnvironment()
151
 
152
  history = []
153
  rewards = []
 
155
  score = 0.0
156
  success = False
157
 
 
 
158
  try:
159
  if hasattr(env, 'reset') and not asyncio.iscoroutinefunction(env.reset):
160
  result = env.reset()
 
177
  break
178
 
179
  action_data = await get_model_message(client, step, obs_dict, last_reward, history)
180
+
 
181
  action_obj = DataWranglerAction(**action_data)
182
+
183
  if hasattr(env, 'step') and not asyncio.iscoroutinefunction(env.step):
184
  result = env.step(action_obj)
185
  else:
186
  result = await env.step(action_obj)
187
+
188
  obs = getattr(result, "observation", result)
189
  obs_dict = {
190
  "columns": getattr(obs, "columns", []),
 
196
 
197
  reward = getattr(result, "reward", getattr(obs, "reward", 0.0)) or 0.0
198
  done = getattr(result, "done", getattr(obs, "is_done", False))
199
+ feedback = obs_dict.get("last_action_feedback", "")
200
+ error = feedback if ("Error" in feedback or "Exception" in feedback) else None
201
 
202
  rewards.append(reward)
203
  steps_taken = step
 
222
  else:
223
  env.close()
224
  except Exception as e:
225
+ _ = e
226
+
227
+ log_end(success=success, steps=steps_taken, rewards=rewards)
228
 
229
  if __name__ == "__main__":
230
  asyncio.run(main())