Upload folder using huggingface_hub
- inference.py +58 -57
inference.py
CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import json
 import os
+import sys
 import textwrap
 from dataclasses import dataclass, field
 from typing import Any
@@ -65,6 +66,8 @@ class ToolRuntimeState:
 
     method_log: list[dict[str, Any]] = field(default_factory=list)
     last_preview: dict[str, Any] | None = None
+    rewards: list[float] = field(default_factory=list)
+    steps: int = 0
 
 
 def log_start(task: str, env: str, model: str) -> None:
@@ -72,23 +75,17 @@ def log_start(task: str, env: str, model: str) -> None:
 
 
 def log_method(tool_name: str, arguments: dict[str, Any], note: str) -> None:
-    print(
-        f"[METHOD] name={tool_name} args={json.dumps(arguments, sort_keys=True)} why={note}",
-        flush=True,
-    )
+    return None
 
 
 def log_payload_generator_methods(tool_name: str, generator_methods: list[dict[str, Any]]) -> None:
-    print(
-        f"[PAYLOAD_GENERATOR_METHODS] source={tool_name} methods={json.dumps(generator_methods, sort_keys=True)}",
-        flush=True,
-    )
+    return None
 
 
 def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
     error_val = error if error else "null"
     print(
-        f"[STEP] step={step} action={action} reward={reward:.
+        f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}",
         flush=True,
     )
 
@@ -98,17 +95,9 @@ def bounded_task_score(score: float) -> float:
     return min(1.0 - SCORE_EPSILON, max(SCORE_EPSILON, score))
 
 
-def log_end(success: bool, steps: int, score: float, method_log: list[dict[str, Any]]) -> None:
-    safe_score = bounded_task_score(score)
-    print(
-        f"[END] success={str(success).lower()} steps={steps} score={safe_score:.6f} methods={len(method_log)}",
-        flush=True,
-    )
-    print(json.dumps({"method_log": method_log}, indent=2), flush=True)
-
-
-def log_task_boundary(task_id: str, difficulty: str, phase: str) -> None:
-    print(f"[TASK_{phase}] task_id={task_id} difficulty={difficulty}", flush=True)
+def log_end(success: bool, steps: int, rewards: list[float]) -> None:
+    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
+    print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
 
 
 def tool_schemas() -> list[dict[str, Any]]:
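
For reference, the reworked [END] record now carries the per-step reward trail instead of a bounded score. A minimal self-contained sketch of the output format, reusing the log_end body added in this commit with illustrative values (not from a real run):

def log_end(success: bool, steps: int, rewards: list[float]) -> None:
    # Same body as the version introduced above.
    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)

log_end(success=True, steps=3, rewards=[0.0, 0.25, 1.0])
# -> [END] success=true steps=3 rewards=0.00,0.25,1.00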
@@ -287,6 +276,34 @@ def preview_text(text: str, limit: int = 220) -> str:
     return text.replace("\n", " ")[:limit]
 
 
+def format_action(tool_name: str, arguments: dict[str, Any]) -> str:
+    if not arguments:
+        return f"{tool_name}()"
+    return preview_text(f"{tool_name}({json.dumps(arguments, sort_keys=True)})")
+
+
+def step_error(result: Any) -> str | None:
+    message = getattr(result.observation, "message", None)
+    return message if result.observation.status == "error" and message else None
+
+
+def record_step(
+    runtime_state: ToolRuntimeState,
+    action: str,
+    result: Any,
+) -> None:
+    reward = float(result.reward or 0.0)
+    runtime_state.steps += 1
+    runtime_state.rewards.append(reward)
+    log_step(
+        step=runtime_state.steps,
+        action=action,
+        reward=reward,
+        done=bool(result.done),
+        error=step_error(result),
+    )
+
+
 async def connect_env() -> MetricTrackerRlEnv:
     if BASE_URL:
         client = MetricTrackerRlEnv(
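
The new format_action helper is what becomes the action field of a [STEP] record. A quick self-contained check of its output (preview_text copied from above; the arguments and the second tool name are made-up examples):

import json

def preview_text(text: str, limit: int = 220) -> str:
    return text.replace("\n", " ")[:limit]

def format_action(tool_name: str, arguments: dict) -> str:
    if not arguments:
        return f"{tool_name}()"
    return preview_text(f"{tool_name}({json.dumps(arguments, sort_keys=True)})")

print(format_action("submit_solution", {"rows": [{"id": 1, "label": "ok"}]}))
# -> submit_solution({"rows": [{"id": 1, "label": "ok"}]})
print(format_action("list_metrics", {}))  # hypothetical no-argument tool
# -> list_metrics()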
@@ -315,6 +332,7 @@ async def execute_tool_call(
     arguments: dict[str, Any],
 ) -> tuple[dict[str, Any], Any | None, MetricTrackerRlObservation]:
     """Execute one model-requested tool locally."""
+    action = format_action(tool_name, arguments)
     if tool_name == "submit_payload_generator":
         methods = [
             PayloadGeneratorMethod(**item)
@@ -329,6 +347,7 @@ async def execute_tool_call(
             }
         )
         result = await env.step(MetricTrackerRlAction(payload_generators=methods))
+        record_step(runtime_state, action, result)
         return (
             {
                 "status": result.observation.status,
@@ -349,6 +368,7 @@ async def execute_tool_call(
     if tool_name == "submit_solution":
         rows = [MetricSubmissionRow(**row) for row in arguments.get("rows", [])]
         result = await env.step(MetricTrackerRlAction(classifications=rows))
+        record_step(runtime_state, action, result)
         return (
             {
                 "status": result.observation.status,
@@ -373,6 +393,7 @@ async def execute_tool_call(
             analysis_args=arguments,
         )
     )
+    record_step(runtime_state, action, result)
     output = result.observation.analysis_result or {
         "method": tool_name,
         "arguments": arguments,
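
Every env.step() in execute_tool_call is now mirrored by a record_step call, so step counts and rewards accumulate on ToolRuntimeState. A self-contained sketch of that bookkeeping, with stub result types standing in for the real env objects and log_step inlined for brevity:

from dataclasses import dataclass, field

@dataclass
class StubObs:  # stand-in for the real observation
    status: str = "ok"
    message: str | None = None

@dataclass
class StubResult:  # stand-in for the real env.step() result
    reward: float = 0.0
    done: bool = False
    observation: StubObs = field(default_factory=StubObs)

@dataclass
class ToolRuntimeState:  # trimmed to the fields added in this commit
    rewards: list[float] = field(default_factory=list)
    steps: int = 0

def step_error(result) -> str | None:
    message = getattr(result.observation, "message", None)
    return message if result.observation.status == "error" and message else None

def record_step(runtime_state, action: str, result) -> None:
    reward = float(result.reward or 0.0)
    runtime_state.steps += 1
    runtime_state.rewards.append(reward)
    print(
        f"[STEP] step={runtime_state.steps} action={action} reward={reward:.2f} "
        f"done={str(result.done).lower()} error={step_error(result) or 'null'}"
    )

state = ToolRuntimeState()
record_step(state, "submit_solution(...)", StubResult(reward=0.25))
record_step(state, "submit_solution(...)", StubResult(reward=1.0, done=True))
# -> [STEP] step=1 action=submit_solution(...) reward=0.25 done=false error=null
# -> [STEP] step=2 action=submit_solution(...) reward=1.00 done=true error=null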
@@ -428,7 +449,7 @@ async def run_agent_loop(
     client: OpenAI,
     env: MetricTrackerRlEnv,
     observation: MetricTrackerRlObservation,
-) -> tuple[Any, str, int, list[dict[str, Any]]]:
+) -> tuple[Any, str, int, list[dict[str, Any]], ToolRuntimeState]:
     """Run a tool-calling loop until the env is solved or the round limit is hit."""
     runtime_state = ToolRuntimeState()
     current_observation = observation
@@ -514,7 +535,7 @@ async def run_agent_loop(
             final_text = (completion.choices[0].message.content or "").strip()
             break
 
-    return last_result, final_text, tool_rounds, runtime_state.method_log
+    return last_result, final_text, tool_rounds, runtime_state.method_log, runtime_state
 
 
 async def run_single_task(
@@ -524,9 +545,9 @@ async def run_single_task(
 ) -> dict[str, Any]:
     """Run one named benchmark task and return a reproducible summary."""
     task_spec = get_task_spec(task_id)
-
+    log_start(task=task_spec.task_id, env=BENCHMARK, model=MODEL_NAME)
     reset_result = await env.reset(task_id=task_spec.task_id)
-    final_result, final_text, tool_rounds, method_log = await run_agent_loop(
+    final_result, final_text, tool_rounds, method_log, runtime_state = await run_agent_loop(
         client,
         env,
         reset_result.observation,
@@ -537,15 +558,6 @@ async def run_single_task(
     reward = float(final_result.reward or 0.0)
     task_score = bounded_task_score(reward)
     success = bool(final_result.done and reward >= 0.999999)
-    log_step(
-        step=1,
-        action=preview_text(final_text or "graded_submission"),
-        reward=reward,
-        done=bool(final_result.done),
-        error=None,
-    )
-    log_end(success=success, steps=1, score=task_score, method_log=method_log)
-    log_task_boundary(task_spec.task_id, task_spec.difficulty, "END")
     return {
         "task_id": task_spec.task_id,
         "difficulty": task_spec.difficulty,
@@ -562,6 +574,8 @@ async def run_single_task(
         "expected_row_count": final_result.observation.expected_row_count,
         "tool_rounds": tool_rounds,
         "assistant_summary": final_text,
+        "steps": runtime_state.steps,
+        "rewards": runtime_state.rewards,
         "reward_breakdown": (
             final_result.observation.reward_breakdown.model_dump()
             if final_result.observation.reward_breakdown
@@ -580,9 +594,16 @@ async def run_single_task_with_retries(
 
     for attempt in range(1, attempts + 1):
         env = None
+        success = False
+        steps = 0
+        rewards: list[float] = []
         try:
             env = await connect_env()
-            return await run_single_task(client, env, task_id)
+            summary = await run_single_task(client, env, task_id)
+            success = bool(summary["success"])
+            steps = int(summary["steps"])
+            rewards = list(summary["rewards"])
+            return summary
         except (ConnectionClosedError, ConnectionError, TimeoutError, OSError) as exc:
             last_error = exc
             print(
@@ -591,6 +612,7 @@ async def run_single_task_with_retries(
                     f"env_connection_error={type(exc).__name__}: {exc}"
                 ),
                 flush=True,
+                file=sys.stderr,
             )
             if attempt >= attempts:
                 raise
@@ -600,6 +622,8 @@ async def run_single_task_with_retries(
                 await env.close()
             except Exception:
                 pass
+            if env is not None:
+                log_end(success=success, steps=steps, rewards=rewards)
 
     assert last_error is not None
     raise last_error
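
One subtlety in the retry loop above: success/steps/rewards keep their per-attempt defaults unless run_single_task returned a summary, so a crashed attempt still emits an [END] line once an env was connected. A self-contained sketch of that flow, assuming the close-and-log block sits in a finally clause (which the context lines suggest) and with connect_env/run_single_task replaced by stubs:

import asyncio

def log_end(success: bool, steps: int, rewards: list[float]) -> None:
    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}")

async def connect_env():
    return object()  # stub: stands in for the real env client

async def run_single_task(env, task_id):
    return {"success": True, "steps": 2, "rewards": [0.5, 1.0]}  # stub summary

async def run_with_retries(task_id: str, attempts: int = 3):
    for attempt in range(1, attempts + 1):
        env = None
        success, steps, rewards = False, 0, []
        try:
            env = await connect_env()
            summary = await run_single_task(env, task_id)
            success = bool(summary["success"])
            steps = int(summary["steps"])
            rewards = list(summary["rewards"])
            return summary
        finally:
            # [END] is only emitted for attempts that actually got an env.
            if env is not None:
                log_end(success=success, steps=steps, rewards=rewards)

asyncio.run(run_with_retries("demo_task"))
# -> [END] success=true steps=2 rewards=0.50,1.00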
@@ -612,32 +636,9 @@ async def main() -> None:
     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
     task_summaries: list[dict[str, Any]] = []
 
-    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
-
     for task_id in DEFAULT_TASK_ORDER:
         task_summaries.append(await run_single_task_with_retries(client, task_id))
 
-    average_score = (
-        round(sum(item["score"] for item in task_summaries) / len(task_summaries), 6)
-        if task_summaries
-        else 0.0
-    )
-    print(
-        json.dumps(
-            {
-                "benchmark": BENCHMARK,
-                "model": MODEL_NAME,
-                "task_count": len(task_summaries),
-                "task_ids": [item["task_id"] for item in task_summaries],
-                "average_score": average_score,
-                "successful_tasks": sum(1 for item in task_summaries if item["success"]),
-                "tasks": task_summaries,
-            },
-            indent=2,
-        ),
-        flush=True,
-    )
-
 
 if __name__ == "__main__":
     asyncio.run(main())