Prithvigg committed on
Commit
3502e46
·
verified ·
1 Parent(s): d151777

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +59 -44
inference.py CHANGED
@@ -33,7 +33,7 @@ from models import SQLAction
33
 
34
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
35
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
36
- MODEL_NAME = os.getenv("MODEL_NAME")
37
  ENV_URL = os.getenv("ENV_URL", "https://prithvigg-queryforge.hf.space")
38
 
39
  MAX_STEPS = 5
@@ -115,34 +115,42 @@ def hr(char="═", width=70):
115
  # ── Per-task agent loop ────────────────────────────────────────────────────────
116
 
117
  def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
118
- result = env_client.reset(task_id=task_id)
119
- obs = result.observation
 
 
 
 
 
 
 
120
 
121
  log_start(task=task_id, model=MODEL_NAME)
122
 
123
- if result.done:
124
- print(f" ERROR loading task: {obs.feedback}")
125
- log_end(success=False, steps=0, score=0.0, rewards=[])
126
- return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
127
-
128
- print(f"\n Task : {obs.task_title} [{obs.task_level}]")
129
-
130
- messages = [
131
- {"role": "system", "content": SYSTEM_PROMPT},
132
- {
133
- "role": "user",
134
- "content": (
135
- f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
136
- "Provide your fixed SQL query."
137
- ),
138
- },
139
- ]
140
-
141
- step = 0
142
- rewards: List[float] = []
143
- success = False
144
-
145
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  while not result.done:
147
  step += 1
148
 
@@ -171,10 +179,10 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
171
 
172
  reward = result.reward or 0.0
173
  rewards.append(reward)
 
174
 
175
- # Determine error string for [STEP] log
176
  if not obs.syntax_valid:
177
- step_error = "syntax_error"
178
  print(f" ✗ Syntax error — query could not be parsed")
179
  elif not obs.execution_success:
180
  step_error = (obs.execution_error or "execution_error")[:120]
@@ -183,12 +191,12 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
183
  step_error = None
184
  print(f" ✓ Executed · rows returned: {obs.rows_returned}")
185
 
186
- done_marker = " ✓ DONE" if result.done else ""
187
  print(f" Score : {score_bar(reward)}{done_marker}")
188
 
189
- log_step(step=step, action=sql, reward=reward, done=result.done, error=step_error)
190
 
191
- if result.done:
192
  break
193
 
194
  print(f"\n ↻ Retrying — score {reward:.3f} below threshold")
@@ -211,27 +219,29 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
211
  ),
212
  })
213
 
214
- success = obs.best_score >= SUCCESS_SCORE_THRESHOLD
 
 
 
 
 
215
 
216
  finally:
217
- log_end(success=success, steps=step, score=obs.best_score, rewards=rewards)
218
 
219
  return {
220
  "task_id": task_id,
221
- "task_title": obs.task_title,
222
- "task_level": obs.task_level,
223
- "best_score": obs.best_score,
224
- "attempts": obs.attempt,
225
- "done": result.done,
226
  }
227
 
228
 
229
  # ── Main ───────────────────────────────────────────────────────────────────────
230
 
231
  def main() -> None:
232
- if not MODEL_NAME:
233
- print("ERROR: MODEL_NAME env var is not set.")
234
- sys.exit(1)
235
  if not API_KEY:
236
  print("ERROR: HF_TOKEN (or API_KEY) is not set.")
237
  sys.exit(1)
@@ -246,10 +256,15 @@ def main() -> None:
246
  hr()
247
 
248
  results = []
249
- with QueryforgeEnv(base_url=ENV_URL).sync() as env_client:
250
- for task_id in TASK_IDS:
251
- print(f"\n{'─' * 70}")
252
- results.append(run_task(task_id, llm, env_client))
 
 
 
 
 
253
 
254
  # ── Results summary ───────────────────────────────────────────────────────
255
  print(f"\n{'═' * 70}")
 
33
 
34
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
35
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
36
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
37
  ENV_URL = os.getenv("ENV_URL", "https://prithvigg-queryforge.hf.space")
38
 
39
  MAX_STEPS = 5
 
115
  # ── Per-task agent loop ────────────────────────────────────────────────────────
116
 
117
  def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
118
+ # Initialise before anything that can throw — guarantees [END] is always emitted.
119
+ step = 0
120
+ rewards: List[float] = []
121
+ success = False
122
+ best_score = 0.0
123
+ task_title = task_id
124
+ task_level = "unknown"
125
+ attempts = 0
126
+ done = False
127
 
128
  log_start(task=task_id, model=MODEL_NAME)
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  try:
131
+ result = env_client.reset(task_id=task_id)
132
+ obs = result.observation
133
+
134
+ if result.done:
135
+ print(f" ERROR loading task: {obs.feedback}")
136
+ log_end(success=False, steps=0, score=0.0, rewards=[])
137
+ return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
138
+
139
+ task_title = obs.task_title
140
+ task_level = obs.task_level
141
+ print(f"\n Task : {task_title} [{task_level}]")
142
+
143
+ messages = [
144
+ {"role": "system", "content": SYSTEM_PROMPT},
145
+ {
146
+ "role": "user",
147
+ "content": (
148
+ f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
149
+ "Provide your fixed SQL query."
150
+ ),
151
+ },
152
+ ]
153
+
154
  while not result.done:
155
  step += 1
156
 
 
179
 
180
  reward = result.reward or 0.0
181
  rewards.append(reward)
182
+ done = result.done
183
 
 
184
  if not obs.syntax_valid:
185
+ step_error: Optional[str] = "syntax_error"
186
  print(f" ✗ Syntax error — query could not be parsed")
187
  elif not obs.execution_success:
188
  step_error = (obs.execution_error or "execution_error")[:120]
 
191
  step_error = None
192
  print(f" ✓ Executed · rows returned: {obs.rows_returned}")
193
 
194
+ done_marker = " ✓ DONE" if done else ""
195
  print(f" Score : {score_bar(reward)}{done_marker}")
196
 
197
+ log_step(step=step, action=sql, reward=reward, done=done, error=step_error)
198
 
199
+ if done:
200
  break
201
 
202
  print(f"\n ↻ Retrying — score {reward:.3f} below threshold")
 
219
  ),
220
  })
221
 
222
+ best_score = obs.best_score
223
+ attempts = obs.attempt
224
+ success = best_score >= SUCCESS_SCORE_THRESHOLD
225
+
226
+ except Exception as exc:
227
+ print(f" FATAL error in task {task_id}: {exc}", flush=True)
228
 
229
  finally:
230
+ log_end(success=success, steps=step, score=best_score, rewards=rewards)
231
 
232
  return {
233
  "task_id": task_id,
234
+ "task_title": task_title,
235
+ "task_level": task_level,
236
+ "best_score": best_score,
237
+ "attempts": attempts,
238
+ "done": done,
239
  }
240
 
241
 
242
  # ── Main ───────────────────────────────────────────────────────────────────────
243
 
244
  def main() -> None:
 
 
 
245
  if not API_KEY:
246
  print("ERROR: HF_TOKEN (or API_KEY) is not set.")
247
  sys.exit(1)
 
256
  hr()
257
 
258
  results = []
259
+ try:
260
+ env_ctx = QueryforgeEnv(base_url=ENV_URL).sync()
261
+ with env_ctx as env_client:
262
+ for task_id in TASK_IDS:
263
+ print(f"\n{'─' * 70}")
264
+ results.append(run_task(task_id, llm, env_client))
265
+ except Exception as exc:
266
+ print(f"FATAL: could not connect to environment at {ENV_URL}: {exc}", flush=True)
267
+ sys.exit(1)
268
 
269
  # ── Results summary ───────────────────────────────────────────────────────
270
  print(f"\n{'═' * 70}")