srishtichugh committed on
Commit
d4930ce
·
1 Parent(s): 778e7e1

Fix openenv validate: add main(), entry point, openenv-core dep, uv.lock

Browse files
Files changed (6) hide show
  1. inference.py +128 -55
  2. inference_log.txt +0 -0
  3. pyproject.toml +5 -1
  4. requirements.txt +1 -0
  5. server/app.py +13 -0
  6. uv.lock +0 -0
inference.py CHANGED
@@ -7,12 +7,19 @@ Required environment variables:
7
  MODEL_NAME — model identifier
8
  HF_TOKEN — API key
9
  ENV_URL — environment server URL (default: http://localhost:8000)
 
 
 
 
 
10
  """
11
 
12
  import json
13
  import os
 
14
  import sys
15
  import time
 
16
  import httpx
17
  from openai import OpenAI
18
 
@@ -66,19 +73,41 @@ Rules:
66
  """
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # ------------------------------------------------------------------
70
  # HTTP helpers
71
  # ------------------------------------------------------------------
72
 
73
  def api_post(path: str, payload: dict = None) -> dict:
74
- url = ENV_URL.rstrip("/") + path
75
  resp = httpx.post(url, json=payload or {}, timeout=30)
76
  resp.raise_for_status()
77
  return resp.json()
78
 
79
 
80
  def api_get(path: str) -> dict:
81
- url = ENV_URL.rstrip("/") + path
82
  resp = httpx.get(url, timeout=10)
83
  resp.raise_for_status()
84
  return resp.json()
@@ -108,58 +137,102 @@ def obs_to_text(obs: dict) -> str:
108
 
109
 
110
  def run_task(task_id: int) -> float:
111
- print(f"\n{'='*60}")
112
- print(f" Running Task {task_id}")
113
- print(f"{'='*60}")
 
 
 
114
 
115
  result = api_post("/reset", {"task_id": task_id})
116
  obs = result["observation"]
117
  history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- for step_num in range(1, 50):
120
- if obs["done"]:
121
- break
122
-
123
- obs_text = obs_to_text(obs)
124
- history.append({"role": "user", "content": obs_text})
125
-
126
- response = client.chat.completions.create(
127
- model = MODEL_NAME,
128
- messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history,
129
- temperature = 0.0,
130
- max_tokens = 256,
131
- )
132
- action_str = response.choices[0].message.content.strip()
133
- history.append({"role": "assistant", "content": action_str})
134
-
135
- # Parse action
136
- try:
137
- action = json.loads(action_str)
138
- except json.JSONDecodeError:
139
- # Try to extract JSON from markdown code fence
140
- import re
141
- m = re.search(r"\{.*\}", action_str, re.DOTALL)
142
- if m:
143
- try:
144
- action = json.loads(m.group())
145
- except Exception:
146
- print(f" Step {step_num}: Could not parse action JSON, skipping.")
147
- break
148
- else:
149
- print(f" Step {step_num}: No JSON found in response, skipping.")
150
  break
151
 
152
- print(f" Step {step_num:2d} | score={obs['current_score']:.4f} | action={json.dumps(action)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- result = api_post("/step", action)
155
- obs = result["observation"]
156
- print(f" → {obs['message']}")
157
 
158
- # Slight delay to stay within rate limits on free-tier endpoints
159
- time.sleep(0.3)
160
 
161
  final_score = obs["current_score"]
162
- print(f"\n Task {task_id} final score: {final_score:.4f} (steps used: {obs['step_count']})")
 
 
 
163
  return final_score
164
 
165
 
@@ -168,33 +241,33 @@ def run_task(task_id: int) -> float:
168
  # ------------------------------------------------------------------
169
 
170
  def main():
171
- print("Data Cleaning OpenEnv Baseline Inference")
172
- print(f"Model : {MODEL_NAME}")
173
- print(f"Env : {ENV_URL}")
174
 
175
  # Smoke-test health endpoint
176
  health = api_get("/health")
177
  assert health.get("status") == "ok", f"Health check failed: {health}"
178
- print("Health check: OK\n")
179
 
180
  scores = {}
181
  for task_id in [1, 2, 3]:
182
  scores[f"task{task_id}"] = run_task(task_id)
183
 
184
- print("\n" + "="*60)
185
- print(" BASELINE RESULTS")
186
- print("="*60)
187
  for k, v in scores.items():
188
- print(f" {k}: {v:.4f}")
189
  avg = sum(scores.values()) / len(scores)
190
- print(f" average: {avg:.4f}")
191
- print("="*60)
192
 
193
  # Write scores to file for automated validators
194
  with open("baseline_scores.json", "w") as f:
195
  json.dump({"scores": scores, "average": avg}, f, indent=2)
196
- print("\nScores written to baseline_scores.json")
197
 
198
 
199
  if __name__ == "__main__":
200
- main()
 
7
  MODEL_NAME — model identifier
8
  HF_TOKEN — API key
9
  ENV_URL — environment server URL (default: http://localhost:8000)
10
+
11
+ STDOUT FORMAT (OpenEnv spec):
12
+ [START] task=<task_name> env=<benchmark> model=<model_name>
13
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
14
+ [END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
15
  """
16
 
17
  import json
18
  import os
19
+ import re
20
  import sys
21
  import time
22
+ from typing import List, Optional
23
  import httpx
24
  from openai import OpenAI
25
 
 
73
  """
74
 
75
 
76
+ # ------------------------------------------------------------------
77
+ # OpenEnv stdout logging helpers
78
+ # ------------------------------------------------------------------
79
+
80
+ def log_start(task: str, env: str, model: str) -> None:
81
+ print(f"[START] task={task} env={env} model={model}", flush=True)
82
+
83
+
84
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
85
+ error_val = error if error else "null"
86
+ done_val = str(done).lower()
87
+ print(
88
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
89
+ flush=True,
90
+ )
91
+
92
+
93
+ def log_end(success: bool, steps: int, rewards: List[float]) -> None:
94
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
95
+ print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
96
+
97
+
98
  # ------------------------------------------------------------------
99
  # HTTP helpers
100
  # ------------------------------------------------------------------
101
 
102
  def api_post(path: str, payload: dict = None) -> dict:
103
+ url = ENV_URL.rstrip("/") + path
104
  resp = httpx.post(url, json=payload or {}, timeout=30)
105
  resp.raise_for_status()
106
  return resp.json()
107
 
108
 
109
  def api_get(path: str) -> dict:
110
+ url = ENV_URL.rstrip("/") + path
111
  resp = httpx.get(url, timeout=10)
112
  resp.raise_for_status()
113
  return resp.json()
 
137
 
138
 
139
  def run_task(task_id: int) -> float:
140
+ task_name = f"data-cleaning-task{task_id}"
141
+
142
+ # Human-readable header (stderr so it doesn't interfere with stdout format)
143
+ print(f"\n{'='*60}", file=sys.stderr)
144
+ print(f" Running Task {task_id}", file=sys.stderr)
145
+ print(f"{'='*60}", file=sys.stderr)
146
 
147
  result = api_post("/reset", {"task_id": task_id})
148
  obs = result["observation"]
149
  history = []
150
+ rewards: List[float] = []
151
+ steps_taken = 0
152
+ success = False
153
+
154
+ log_start(task=task_name, env="data-cleaning-openenv", model=MODEL_NAME)
155
+
156
+ try:
157
+ for step_num in range(1, 50):
158
+ if obs["done"]:
159
+ success = obs["current_score"] >= 0.95
160
+ break
161
+
162
+ obs_text = obs_to_text(obs)
163
+ history.append({"role": "user", "content": obs_text})
164
+
165
+ try:
166
+ response = client.chat.completions.create(
167
+ model = MODEL_NAME,
168
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history,
169
+ temperature = 0.0,
170
+ max_tokens = 256,
171
+ )
172
+ action_str = response.choices[0].message.content.strip()
173
+ except Exception as exc:
174
+ print(f" Step {step_num}: LLM call failed: {exc}", file=sys.stderr)
175
+ log_step(step_num, "null", 0.0, True, str(exc))
176
+ break
177
 
178
+ history.append({"role": "assistant", "content": action_str})
179
+
180
+ # Parse action JSON
181
+ action = None
182
+ try:
183
+ action = json.loads(action_str)
184
+ except json.JSONDecodeError:
185
+ m = re.search(r"\{.*\}", action_str, re.DOTALL)
186
+ if m:
187
+ try:
188
+ action = json.loads(m.group())
189
+ except Exception:
190
+ pass
191
+
192
+ if action is None:
193
+ print(f" Step {step_num}: Could not parse action JSON, skipping.", file=sys.stderr)
194
+ log_step(step_num, action_str, -0.05, False, "json_parse_error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  break
196
 
197
+ action_label = json.dumps(action, separators=(",", ":"))
198
+ print(
199
+ f" Step {step_num:2d} | score={obs['current_score']:.4f} | action={action_label}",
200
+ file=sys.stderr,
201
+ )
202
+
203
+ result = api_post("/step", action)
204
+ obs = result["observation"]
205
+ step_reward = result["reward"]
206
+ done = result["done"]
207
+ error_msg = None if obs["message"].startswith("Fill") or step_reward >= 0 else obs["message"]
208
+
209
+ print(f" -> {obs['message']}", file=sys.stderr)
210
+
211
+ rewards.append(step_reward)
212
+ steps_taken = step_num
213
+
214
+ log_step(
215
+ step = step_num,
216
+ action = action_label,
217
+ reward = step_reward,
218
+ done = done,
219
+ error = error_msg,
220
+ )
221
+
222
+ if done:
223
+ success = obs["current_score"] >= 0.95
224
+ break
225
 
226
+ time.sleep(0.3)
 
 
227
 
228
+ finally:
229
+ log_end(success=success, steps=steps_taken, rewards=rewards)
230
 
231
  final_score = obs["current_score"]
232
+ print(
233
+ f"\n Task {task_id} final score: {final_score:.4f} (steps used: {obs['step_count']})",
234
+ file=sys.stderr,
235
+ )
236
  return final_score
237
 
238
 
 
241
  # ------------------------------------------------------------------
242
 
243
  def main():
244
+ print("Data Cleaning OpenEnv -- Baseline Inference", file=sys.stderr)
245
+ print(f"Model : {MODEL_NAME}", file=sys.stderr)
246
+ print(f"Env : {ENV_URL}", file=sys.stderr)
247
 
248
  # Smoke-test health endpoint
249
  health = api_get("/health")
250
  assert health.get("status") == "ok", f"Health check failed: {health}"
251
+ print("Health check: OK\n", file=sys.stderr)
252
 
253
  scores = {}
254
  for task_id in [1, 2, 3]:
255
  scores[f"task{task_id}"] = run_task(task_id)
256
 
257
+ print("\n" + "="*60, file=sys.stderr)
258
+ print(" BASELINE RESULTS", file=sys.stderr)
259
+ print("="*60, file=sys.stderr)
260
  for k, v in scores.items():
261
+ print(f" {k}: {v:.4f}", file=sys.stderr)
262
  avg = sum(scores.values()) / len(scores)
263
+ print(f" average: {avg:.4f}", file=sys.stderr)
264
+ print("="*60, file=sys.stderr)
265
 
266
  # Write scores to file for automated validators
267
  with open("baseline_scores.json", "w") as f:
268
  json.dump({"scores": scores, "average": avg}, f, indent=2)
269
+ print("\nScores written to baseline_scores.json", file=sys.stderr)
270
 
271
 
272
  if __name__ == "__main__":
273
+ main()
inference_log.txt CHANGED
Binary files a/inference_log.txt and b/inference_log.txt differ
 
pyproject.toml CHANGED
@@ -12,11 +12,15 @@ dependencies = [
12
  "faker>=18.0.0",
13
  "openai>=1.0.0",
14
  "httpx>=0.25.0",
 
15
  ]
16
 
 
 
 
17
  [build-system]
18
  requires = ["hatchling"]
19
  build-backend = "hatchling.build"
20
 
21
  [tool.hatch.build.targets.wheel]
22
- packages = ["server"]
 
12
  "faker>=18.0.0",
13
  "openai>=1.0.0",
14
  "httpx>=0.25.0",
15
+ "openenv-core>=0.2.0",
16
  ]
17
 
18
+ [project.scripts]
19
+ serve = "server.app:main"
20
+
21
  [build-system]
22
  requires = ["hatchling"]
23
  build-backend = "hatchling.build"
24
 
25
  [tool.hatch.build.targets.wheel]
26
+ packages = ["server"]
requirements.txt CHANGED
@@ -6,3 +6,4 @@ numpy>=1.24.0
6
  faker>=18.0.0
7
  openai>=1.0.0
8
  httpx>=0.25.0
 
 
6
  faker>=18.0.0
7
  openai>=1.0.0
8
  httpx>=0.25.0
9
+ openenv-core>=0.2.0
server/app.py CHANGED
@@ -6,6 +6,7 @@ Endpoints: GET /health, POST /reset, POST /step, POST /state, GET /docs
6
  from typing import Optional
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
 
9
 
10
  from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
11
  from server.environment import DataCleaningEnvironment
@@ -61,3 +62,15 @@ def step(action: DataCleaningAction):
61
  @app.post("/state", response_model=DataCleaningState)
62
  def state():
63
  return env.state()
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from typing import Optional
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
+ import uvicorn
10
 
11
  from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
12
  from server.environment import DataCleaningEnvironment
 
62
  @app.post("/state", response_model=DataCleaningState)
63
  def state():
64
  return env.state()
65
+
66
+
67
+ # ------------------------------------------------------------------
68
+ # Entry point (required by openenv-core and [project.scripts])
69
+ # ------------------------------------------------------------------
70
+
71
+ def main():
72
+ uvicorn.run("server.app:app", host="0.0.0.0", port=8000)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff