Pratap-K commited on
Commit
58241d0
Β·
1 Parent(s): 607c14d
Files changed (1) hide show
  1. inference.py +26 -7
inference.py CHANGED
@@ -361,21 +361,23 @@ def env_state() -> dict:
361
 
362
  # ─── Main ─────────────────────────────────────────────────────────────────────
363
 
364
- def main() -> None:
365
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
366
-
367
  rewards: list[float] = []
368
  steps_taken = 0
369
  score = 0.0
370
  success = False
371
 
372
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
 
 
 
373
 
374
  try:
375
  # Reset environment
376
- obs = env_reset(task=TASK_NAME, seed=SEED)
377
 
378
- for step in range(1, MAX_STEPS + 1):
379
  if obs.get("content_item", {}).get("content_id") == "__terminal__":
380
  break
381
 
@@ -421,12 +423,29 @@ def main() -> None:
421
  success = score >= SUCCESS_SCORE_THRESHOLD
422
 
423
  except Exception as e:
424
- print(f"[DEBUG] Fatal error: {e}\n{traceback.format_exc()}", flush=True)
425
  success = False
426
 
427
  finally:
428
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
429
 
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  if __name__ == "__main__":
432
  main()
 
361
 
362
  # ─── Main ─────────────────────────────────────────────────────────────────────
363
 
364
+ def run_task(client: OpenAI, task_name: str, seed: int) -> None:
365
+ """Run inference for a specific task and log results."""
 
366
  rewards: list[float] = []
367
  steps_taken = 0
368
  score = 0.0
369
  success = False
370
 
371
+ # Get max steps for this specific task
372
+ max_steps = int(os.getenv("MAX_STEPS_OVERRIDE", str(TASK_MAX_STEPS.get(task_name, 10))))
373
+
374
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
375
 
376
  try:
377
  # Reset environment
378
+ obs = env_reset(task=task_name, seed=seed)
379
 
380
+ for step in range(1, max_steps + 1):
381
  if obs.get("content_item", {}).get("content_id") == "__terminal__":
382
  break
383
 
 
423
  success = score >= SUCCESS_SCORE_THRESHOLD
424
 
425
  except Exception as e:
426
+ print(f"[DEBUG] Fatal error in task {task_name}: {e}\n{traceback.format_exc()}", flush=True)
427
  success = False
428
 
429
  finally:
430
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
431
 
432
 
433
+ def main() -> None:
434
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
435
+
436
+ # List of tasks to iterate through
437
+ tasks_to_run = [
438
+ "single-label-classify",
439
+ "multi-label-classify",
440
+ "ad-policy-compliance",
441
+ "thread-moderation-hard"
442
+ ]
443
+
444
+ # If MODERATION_TASK is set and valid, we could prioritize it,
445
+ # but the requirement is to iterate through all.
446
+ for task in tasks_to_run:
447
+ run_task(client, task, SEED)
448
+
449
+
450
  if __name__ == "__main__":
451
  main()