# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tool Use Env Environment Client."""

from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult

from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState


class ToolUseEnv(EnvClient[ToolUseAction, ToolUseObservation, ToolUseState]):
    """Client for the Tool Use environment.

    Supplies the three conversion hooks between the typed models
    (``ToolUseAction``, ``ToolUseObservation``, ``ToolUseState``) and the
    plain-dict payloads exchanged with the server.
    """

    def _step_payload(self, action: ToolUseAction) -> dict:
        """Convert an action into the dict sent with a step request.

        Args:
            action: The action to serialize. Only ``action.action_type`` is
                read; no other attribute of the action is transmitted.

        Returns:
            A dict with the single key ``"action_type"``.
        """
        return {"action_type": action.action_type}

    def _parse_result(self, payload: dict) -> StepResult[ToolUseObservation]:
        """Convert a server response dict into a ``StepResult``.

        Args:
            payload: Response dict. ``"done"`` and ``"reward"`` are read from
                the top level; ``"query"``, ``"tool_output"`` and
                ``"message"`` are read from the nested ``"observation"`` dict.

        Returns:
            A ``StepResult`` wrapping the parsed ``ToolUseObservation``; the
            same ``done`` and ``reward`` values are set on both objects.
        """
        # `or {}` also covers an explicit `"observation": null`, which a plain
        # `.get("observation", {})` default would pass through as None.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult cannot disagree.
        done = payload.get("done", False)
        reward = payload.get("reward")

        observation = ToolUseObservation(
            done=done,
            reward=reward,
            query=obs_data.get("query", ""),
            tool_output=obs_data.get("tool_output"),
            message=obs_data.get("message", ""),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict) -> ToolUseState:
        """Convert a state response dict into a ``ToolUseState``.

        Args:
            payload: Response dict for a state request.

        Returns:
            A ``ToolUseState`` built from the payload. Missing keys fall back
            to ``None`` for ``episode_id``, ``0`` for ``step_count`` and the
            empty string for the remaining fields.
        """
        return ToolUseState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            current_query=payload.get("current_query", ""),
            correct_action=payload.get("correct_action", ""),
            correct_answer=payload.get("correct_answer", ""),
            difficulty=payload.get("difficulty", ""),
        )