| |
| |
| |
| |
| |
|
|
| """TB2 Environment Client.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
|
|
| |
| try: |
| |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_client import EnvClient |
|
|
| from .models import Tbench2Action, Tbench2Observation, Tbench2State |
| except ImportError: |
| from models import Tbench2Action, Tbench2Observation, Tbench2State |
|
|
| |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_client import EnvClient |
|
|
|
|
class Tbench2Env(EnvClient[Tbench2Action, Tbench2Observation, Tbench2State]):
    """HTTP client for the TB2 environment.

    Converts between the typed ``Tbench2Action`` / ``Tbench2Observation`` /
    ``Tbench2State`` models and plain ``dict`` payloads. The three methods
    below are the only overrides of ``EnvClient`` in this class (presumably
    invoked by ``EnvClient`` when sending/receiving — confirm in base class).
    """

    def _step_payload(self, action: Tbench2Action) -> dict[str, Any]:
        """Serialize ``action`` into the dict sent for a step.

        Every action field is forwarded unconditionally (``None`` values
        included), so the payload always carries the same set of keys.
        """
        return {
            "action_type": action.action_type,
            "command": action.command,
            "session_id": action.session_id,
            "block": action.block,
            "wait_seconds": action.wait_seconds,
            "file_path": action.file_path,
            "content": action.content,
        }

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tbench2Observation]:
        """Build a ``StepResult`` from a server response payload.

        Args:
            payload: Response dict. Observation fields are read from the
                nested ``"observation"`` mapping; ``"reward"`` and ``"done"``
                are read from the top level.

        Returns:
            A ``StepResult`` whose observation carries the same ``reward`` and
            ``done`` values as the result itself. Missing fields fall back to
            empty strings, ``success=True``, ``done=False``, ``reward=None``
            and empty dicts for ``info`` / ``metadata``.
        """
        # ``or {}`` (not a .get default) also covers an explicit
        # ``"observation": null``, which would otherwise raise AttributeError.
        obs_data = payload.get("observation") or {}
        # Read once so the observation and the StepResult cannot disagree.
        reward = payload.get("reward")
        done = payload.get("done", False)
        observation = Tbench2Observation(
            instruction=obs_data.get("instruction", ""),
            output=obs_data.get("output", ""),
            success=obs_data.get("success", True),
            error=obs_data.get("error", ""),
            task_id=obs_data.get("task_id", ""),
            task_path=obs_data.get("task_path", ""),
            session_id=obs_data.get("session_id"),
            action_type=obs_data.get("action_type", ""),
            # Coalesce explicit nulls so these dict-valued fields stay dicts.
            info=obs_data.get("info") or {},
            reward=reward,
            done=done,
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict[str, Any]) -> Tbench2State:
        """Build a ``Tbench2State`` from a state payload.

        Missing keys fall back to ``episode_id=None``, ``step_count=0``,
        ``terminal_ready=False`` and empty strings for the text fields.
        """
        return Tbench2State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", ""),
            task_path=payload.get("task_path", ""),
            terminal_ready=payload.get("terminal_ready", False),
            last_action_type=payload.get("last_action_type", ""),
            last_command=payload.get("last_command", ""),
            last_output=payload.get("last_output", ""),
        )
|
|