kushalExplores committed
Commit 15a0c0f · verified
1 Parent(s): 891d2a4

Upload folder using huggingface_hub

Files changed (1)
  1. inference.py +58 -57
inference.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import json
 import os
+import sys
 import textwrap
 from dataclasses import dataclass, field
 from typing import Any
@@ -65,6 +66,8 @@ class ToolRuntimeState:

     method_log: list[dict[str, Any]] = field(default_factory=list)
     last_preview: dict[str, Any] | None = None
+    rewards: list[float] = field(default_factory=list)
+    steps: int = 0


 def log_start(task: str, env: str, model: str) -> None:
@@ -72,23 +75,17 @@ def log_start(task: str, env: str, model: str) -> None:


 def log_method(tool_name: str, arguments: dict[str, Any], note: str) -> None:
-    print(
-        f"[METHOD] name={tool_name} args={json.dumps(arguments, sort_keys=True)} why={note}",
-        flush=True,
-    )
+    return None


 def log_payload_generator_methods(tool_name: str, generator_methods: list[dict[str, Any]]) -> None:
-    print(
-        f"[PAYLOAD_GENERATOR_METHODS] source={tool_name} methods={json.dumps(generator_methods, sort_keys=True)}",
-        flush=True,
-    )
+    return None


 def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
     error_val = error if error else "null"
     print(
-        f"[STEP] step={step} action={action} reward={reward:.3f} done={str(done).lower()} error={error_val}",
+        f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}",
         flush=True,
     )

@@ -98,17 +95,9 @@ def bounded_task_score(score: float) -> float:
     return min(1.0 - SCORE_EPSILON, max(SCORE_EPSILON, score))


-def log_end(success: bool, steps: int, score: float, method_log: list[dict[str, Any]]) -> None:
-    safe_score = bounded_task_score(score)
-    print(
-        f"[END] success={str(success).lower()} steps={steps} score={safe_score:.6f} methods={len(method_log)}",
-        flush=True,
-    )
-    print(json.dumps({"method_log": method_log}, indent=2), flush=True)
-
-
-def log_task_boundary(task_id: str, difficulty: str, phase: str) -> None:
-    print(f"[TASK_{phase}] task_id={task_id} difficulty={difficulty}", flush=True)
+def log_end(success: bool, steps: int, rewards: list[float]) -> None:
+    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
+    print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)


 def tool_schemas() -> list[dict[str, Any]]:
@@ -287,6 +276,34 @@ def preview_text(text: str, limit: int = 220) -> str:
     return text.replace("\n", " ")[:limit]


+def format_action(tool_name: str, arguments: dict[str, Any]) -> str:
+    if not arguments:
+        return f"{tool_name}()"
+    return preview_text(f"{tool_name}({json.dumps(arguments, sort_keys=True)})")
+
+
+def step_error(result: Any) -> str | None:
+    message = getattr(result.observation, "message", None)
+    return message if result.observation.status == "error" and message else None
+
+
+def record_step(
+    runtime_state: ToolRuntimeState,
+    action: str,
+    result: Any,
+) -> None:
+    reward = float(result.reward or 0.0)
+    runtime_state.steps += 1
+    runtime_state.rewards.append(reward)
+    log_step(
+        step=runtime_state.steps,
+        action=action,
+        reward=reward,
+        done=bool(result.done),
+        error=step_error(result),
+    )
+
+
 async def connect_env() -> MetricTrackerRlEnv:
     if BASE_URL:
         client = MetricTrackerRlEnv(
@@ -315,6 +332,7 @@ async def execute_tool_call(
     arguments: dict[str, Any],
 ) -> tuple[dict[str, Any], Any | None, MetricTrackerRlObservation]:
     """Execute one model-requested tool locally."""
+    action = format_action(tool_name, arguments)
     if tool_name == "submit_payload_generator":
         methods = [
             PayloadGeneratorMethod(**item)
@@ -329,6 +347,7 @@ async def execute_tool_call(
             }
         )
         result = await env.step(MetricTrackerRlAction(payload_generators=methods))
+        record_step(runtime_state, action, result)
         return (
             {
                 "status": result.observation.status,
@@ -349,6 +368,7 @@ async def execute_tool_call(
     if tool_name == "submit_solution":
         rows = [MetricSubmissionRow(**row) for row in arguments.get("rows", [])]
         result = await env.step(MetricTrackerRlAction(classifications=rows))
+        record_step(runtime_state, action, result)
         return (
             {
                 "status": result.observation.status,
@@ -373,6 +393,7 @@ async def execute_tool_call(
             analysis_args=arguments,
         )
     )
+    record_step(runtime_state, action, result)
     output = result.observation.analysis_result or {
         "method": tool_name,
         "arguments": arguments,
@@ -428,7 +449,7 @@ async def run_agent_loop(
     client: OpenAI,
     env: MetricTrackerRlEnv,
     observation: MetricTrackerRlObservation,
-) -> tuple[Any, str, int, list[dict[str, Any]]]:
+) -> tuple[Any, str, int, list[dict[str, Any]], ToolRuntimeState]:
     """Run a tool-calling loop until the env is solved or the round limit is hit."""
     runtime_state = ToolRuntimeState()
     current_observation = observation
@@ -514,7 +535,7 @@
             final_text = (completion.choices[0].message.content or "").strip()
             break

-    return last_result, final_text, tool_rounds, runtime_state.method_log
+    return last_result, final_text, tool_rounds, runtime_state.method_log, runtime_state


 async def run_single_task(
@@ -524,9 +545,9 @@
 ) -> dict[str, Any]:
     """Run one named benchmark task and return a reproducible summary."""
     task_spec = get_task_spec(task_id)
-    log_task_boundary(task_spec.task_id, task_spec.difficulty, "START")
+    log_start(task=task_spec.task_id, env=BENCHMARK, model=MODEL_NAME)
     reset_result = await env.reset(task_id=task_spec.task_id)
-    final_result, final_text, tool_rounds, method_log = await run_agent_loop(
+    final_result, final_text, tool_rounds, method_log, runtime_state = await run_agent_loop(
         client,
         env,
         reset_result.observation,
@@ -537,15 +558,6 @@
     reward = float(final_result.reward or 0.0)
     task_score = bounded_task_score(reward)
     success = bool(final_result.done and reward >= 0.999999)
-    log_step(
-        step=1,
-        action=preview_text(final_text or "graded_submission"),
-        reward=reward,
-        done=bool(final_result.done),
-        error=None,
-    )
-    log_end(success=success, steps=1, score=task_score, method_log=method_log)
-    log_task_boundary(task_spec.task_id, task_spec.difficulty, "END")
     return {
         "task_id": task_spec.task_id,
         "difficulty": task_spec.difficulty,
@@ -562,6 +574,8 @@ async def run_single_task(
         "expected_row_count": final_result.observation.expected_row_count,
         "tool_rounds": tool_rounds,
         "assistant_summary": final_text,
+        "steps": runtime_state.steps,
+        "rewards": runtime_state.rewards,
         "reward_breakdown": (
             final_result.observation.reward_breakdown.model_dump()
             if final_result.observation.reward_breakdown
@@ -580,9 +594,16 @@ async def run_single_task_with_retries(

     for attempt in range(1, attempts + 1):
         env = None
+        success = False
+        steps = 0
+        rewards: list[float] = []
         try:
             env = await connect_env()
-            return await run_single_task(client, env, task_id)
+            summary = await run_single_task(client, env, task_id)
+            success = bool(summary["success"])
+            steps = int(summary["steps"])
+            rewards = list(summary["rewards"])
+            return summary
         except (ConnectionClosedError, ConnectionError, TimeoutError, OSError) as exc:
             last_error = exc
             print(
@@ -591,6 +612,7 @@
                     f"env_connection_error={type(exc).__name__}: {exc}"
                 ),
                 flush=True,
+                file=sys.stderr,
             )
             if attempt >= attempts:
                 raise
@@ -600,6 +622,8 @@
                     await env.close()
                 except Exception:
                     pass
+            if env is not None:
+                log_end(success=success, steps=steps, rewards=rewards)

     assert last_error is not None
     raise last_error
@@ -612,32 +636,9 @@ async def main() -> None:
     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
     task_summaries: list[dict[str, Any]] = []

-    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
-
     for task_id in DEFAULT_TASK_ORDER:
         task_summaries.append(await run_single_task_with_retries(client, task_id))

-    average_score = (
-        round(sum(item["score"] for item in task_summaries) / len(task_summaries), 6)
-        if task_summaries
-        else 0.0
-    )
-    print(
-        json.dumps(
-            {
-                "benchmark": BENCHMARK,
-                "model": MODEL_NAME,
-                "task_count": len(task_summaries),
-                "task_ids": [item["task_id"] for item in task_summaries],
-                "average_score": average_score,
-                "successful_tasks": sum(1 for item in task_summaries if item["success"]),
-                "tasks": task_summaries,
-            },
-            indent=2,
-        ),
-        flush=True,
-    )
-


 if __name__ == "__main__":
     asyncio.run(main())
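For reference, a minimal sketch of what the reworked per-step logging in this commit prints; it assumes inference.py is importable from the working directory, and the tool names, rewards, and step counts below are illustrative only:

# Illustrative sketch, not part of the commit: exercises the new log_step/log_end format strings.
from inference import log_end, log_step

log_step(step=1, action="submit_payload_generator(...)", reward=0.25, done=False, error=None)
log_step(step=2, action="submit_solution(...)", reward=1.0, done=True, error=None)
log_end(success=True, steps=2, rewards=[0.25, 1.0])

# Expected stdout:
# [STEP] step=1 action=submit_payload_generator(...) reward=0.25 done=false error=null
# [STEP] step=2 action=submit_solution(...) reward=1.00 done=true error=null
# [END] success=true steps=2 rewards=0.25,1.00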