# # Copyright (c) Meta Platforms, Inc. and affiliates. # # All rights reserved. # # # # This source code is licensed under the BSD-style license found in the # # LICENSE file in the root directory of this source tree. # """Tool Use Env Environment Client.""" # from typing import Dict # from openenv.core import EnvClient # from openenv.core.client_types import StepResult # from openenv.core.env_server.types import State # from .models import ToolUseAction, ToolUseObservation # class ToolUseEnv( # EnvClient[ToolUseAction, ToolUseObservation, State] # ): # """ # Client for the Tool Use Env Environment. # This client maintains a persistent WebSocket connection to the environment server, # enabling efficient multi-step interactions with lower latency. # Each client instance has its own dedicated environment session on the server. # Example: # >>> # Connect to a running server # >>> with ToolUseEnv(base_url="http://localhost:8000") as client: # ... result = client.reset() # ... print(result.observation.echoed_message) # ... # ... result = client.step(ToolUseAction(message="Hello!")) # ... print(result.observation.echoed_message) # Example with Docker: # >>> # Automatically start container and connect # >>> client = ToolUseEnv.from_docker_image("tool_use_env-env:latest") # >>> try: # ... result = client.reset() # ... result = client.step(ToolUseAction(message="Test")) # ... finally: # ... client.close() # """ # def _step_payload(self, action: ToolUseAction) -> Dict: # """ # Convert ToolUseAction to JSON payload for step message. # Args: # action: ToolUseAction instance # Returns: # Dictionary representation suitable for JSON encoding # """ # return { # "message": action.message, # } # def _parse_result(self, payload: Dict) -> StepResult[ToolUseObservation]: # """ # Parse server response into StepResult[ToolUseObservation]. 
#         Args:
#             payload: JSON response data from server
#
#         Returns:
#             StepResult with ToolUseObservation
#         """
#         obs_data = payload.get("observation", {})
#         observation = ToolUseObservation(
#             echoed_message=obs_data.get("echoed_message", ""),
#             message_length=obs_data.get("message_length", 0),
#             done=payload.get("done", False),
#             reward=payload.get("reward"),
#             metadata=obs_data.get("metadata", {}),
#         )
#
#         return StepResult(
#             observation=observation,
#             reward=payload.get("reward"),
#             done=payload.get("done", False),
#         )
#
#     def _parse_state(self, payload: Dict) -> State:
#         """
#         Parse server response into State object.
#
#         Args:
#             payload: JSON response from state request
#
#         Returns:
#             State object with episode_id and step_count
#         """
#         return State(
#             episode_id=payload.get("episode_id"),
#             step_count=payload.get("step_count", 0),
#         )

from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult
from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState


class ToolUseEnv(EnvClient[ToolUseAction, ToolUseObservation, ToolUseState]):
    """Client-side adapter between the tool-use models and plain dictionaries.

    The three hooks below turn a ``ToolUseAction`` into a step payload and
    rebuild ``ToolUseObservation`` / ``ToolUseState`` objects from the
    dictionaries passed to them.
    """

    def _step_payload(self, action: ToolUseAction) -> dict:
        """Return the step payload: one dictionary entry per action field.

        Args:
            action: Action whose attributes are copied into the payload.

        Returns:
            Dictionary keyed by field name, in the order listed below.
        """
        wire_fields = (
            "action_type",
            "artifact_id",
            "query",
            "message",
            "resolution_code",
        )
        return {name: getattr(action, name) for name in wire_fields}

    def _parse_result(self, payload: dict) -> StepResult:
        """Build a ``StepResult`` from a response payload.

        ``done`` and ``reward`` are read from the top level of ``payload``
        and used both for the observation and for the result itself. All
        other observation fields are read from ``payload["observation"]``,
        falling back to the values in ``fallbacks`` when a key is absent.

        Args:
            payload: Response dictionary to decode.

        Returns:
            ``StepResult`` wrapping the reconstructed ``ToolUseObservation``.
        """
        raw_obs = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward")

        # Rebuilt on every call so the list/dict fallbacks are fresh objects
        # and never shared between two observations.
        fallbacks = {
            "task_id": "",
            "difficulty": "easy",
            "objective": "",
            "customer_message": "",
            "workspace_summary": "",
            "available_actions": [],
            "available_resolution_codes": [],
            "collected_evidence": [],
            "last_tool_result": None,
            "last_action_error": None,
            "remaining_steps": 0,
            "current_score": 0.0,
            "metadata": {},
        }
        obs_fields = {
            name: raw_obs.get(name, fallback)
            for name, fallback in fallbacks.items()
        }

        return StepResult(
            observation=ToolUseObservation(done=done, reward=reward, **obs_fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict) -> ToolUseState:
        """Build a ``ToolUseState`` from a state payload.

        Every field is looked up by its own name at the top level of
        ``payload``; a missing key takes the value given in ``fallbacks``.

        Args:
            payload: State dictionary to decode.

        Returns:
            The reconstructed ``ToolUseState``.
        """
        # Rebuilt on every call so the list/dict fallbacks are fresh objects
        # and never shared between two states.
        fallbacks = {
            "episode_id": None,
            "step_count": 0,
            "task_id": "",
            "task_name": "",
            "difficulty": "",
            "objective": "",
            "cumulative_reward": 0.0,
            "final_score": 0.0,
            "drafted_reply": None,
            "resolution_code": None,
            "expected_resolution_code": "",
            "required_evidence": [],
            "collected_evidence": [],
            "action_history": [],
            "repeat_action_count": 0,
            "last_action_error": None,
            "known_artifacts": {},
            "known_policies": {},
        }
        state_fields = {
            name: payload.get(name, fallback)
            for name, fallback in fallbacks.items()
        }
        return ToolUseState(**state_fields)