Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Tests for mode selection in OpenEnv clients and environments. | |
| This file combines two aspects of mode selection: | |
| 1. Client mode selection (from main): Tests for selecting between WebSocket (Gym-style) | |
| and MCP modes through constructor parameters and environment variables. The mode | |
| selection determines which protocol the client uses to communicate with the server. | |
| 2. Environment code mode (from issue #347): Tests for mode selection between tool-calling | |
| and code mode. Per RFC 003, MCP environments should support two modes: | |
| - Tool-calling mode: one tool call per step (traditional MCP) | |
| - Code mode: code blocks with direct Python function calls (CodeAct pattern) | |
| Test coverage: | |
| - Client: Mode selection via constructor parameter and environment variable | |
| - Client: GenericEnvClient and MCPToolClient mode behavior | |
| - Environment: Code mode with get_callables() and execute_code() | |
| - Environment: Code mode with mode-aware tool registration | |
| """ | |
| import os | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from fastmcp import FastMCP | |
| from openenv.core.env_server.mcp_environment import MCPEnvironment | |
| from openenv.core.env_server.mcp_types import ListToolsAction, ListToolsObservation | |
| from openenv.core.env_server.types import Observation, State | |
| from openenv.core.generic_client import GenericEnvClient | |
| from openenv.core.mcp_client import MCPToolClient | |
| # ============================================================================ | |
| # Client Mode Selection Tests (from main) | |
| # ============================================================================ | |
| # ============================================================================ | |
| # Test Fixtures - Client Mode Tests | |
| # ============================================================================ | |
@pytest.fixture
def clean_env():
    """Run a test with ``OPENENV_CLIENT_MODE`` absent from the environment.

    The variable is removed before the test body runs. Afterwards the
    environment is put back exactly as it was: a previous value is restored,
    and if there was none, anything the test itself set is removed again.
    """
    saved = os.environ.pop("OPENENV_CLIENT_MODE", None)
    yield
    if saved is None:
        # The variable did not exist before; do not let a value written by
        # the test leak into later tests.
        os.environ.pop("OPENENV_CLIENT_MODE", None)
    else:
        os.environ["OPENENV_CLIENT_MODE"] = saved
@pytest.fixture
def mock_websocket():
    """Provide a ``MagicMock`` standing in for a WebSocket connection.

    Returns:
        A mock whose ``recv()`` returns a JSON string describing an empty
        ``"response"`` message.
    """
    connection = MagicMock()
    connection.recv.return_value = '{"type": "response", "data": {}}'
    return connection
| # ============================================================================ | |
| # Constructor Parameter Mode Selection Tests | |
| # ============================================================================ | |
class TestConstructorModeSelection:
    """Mode selection driven by the ``mode`` constructor argument."""

    def test_default_mode_is_simulation(self, clean_env):
        """With no mode given, the client ends up in 'simulation'."""
        default_client = GenericEnvClient(base_url="http://localhost:8000")

        # The attribute must exist before its value is compared.
        assert hasattr(default_client, "_mode")
        assert default_client._mode == "simulation"

    def test_explicit_simulation_mode(self, clean_env):
        """Passing mode='simulation' is stored as-is."""
        sim_client = GenericEnvClient(
            base_url="http://localhost:8000", mode="simulation"
        )
        assert sim_client._mode == "simulation"

    def test_explicit_production_mode(self, clean_env):
        """Passing mode='production' is stored as-is."""
        prod_client = GenericEnvClient(
            base_url="http://localhost:8000", mode="production"
        )
        assert prod_client._mode == "production"

    def test_invalid_mode_raises_error(self, clean_env):
        """An unknown mode is rejected with a ValueError naming the options."""
        with pytest.raises(ValueError) as exc_info:
            GenericEnvClient(base_url="http://localhost:8000", mode="invalid_mode")

        # Checked after the block: statements placed after the raising call
        # inside ``pytest.raises`` would never execute.
        message = str(exc_info.value).lower()
        assert "mode" in message
        assert "simulation" in message
        assert "production" in message

    def test_case_insensitive_mode(self, clean_env):
        """Upper-case mode names are normalised to lower case."""
        upper_sim = GenericEnvClient(
            base_url="http://localhost:8000", mode="SIMULATION"
        )
        upper_prod = GenericEnvClient(
            base_url="http://localhost:8000", mode="PRODUCTION"
        )
        assert upper_sim._mode == "simulation"
        assert upper_prod._mode == "production"
| # ============================================================================ | |
| # Environment Variable Mode Selection Tests | |
| # ============================================================================ | |
class TestEnvironmentVariableModeSelection:
    """Mode selection driven by the OPENENV_CLIENT_MODE environment variable."""

    def test_env_var_simulation_mode(self):
        """OPENENV_CLIENT_MODE=simulation selects simulation mode."""
        overrides = {"OPENENV_CLIENT_MODE": "simulation"}
        with patch.dict(os.environ, overrides):
            client = GenericEnvClient(base_url="http://localhost:8000")
            assert client._mode == "simulation"

    def test_env_var_production_mode(self):
        """OPENENV_CLIENT_MODE=production selects production mode."""
        overrides = {"OPENENV_CLIENT_MODE": "production"}
        with patch.dict(os.environ, overrides):
            client = GenericEnvClient(base_url="http://localhost:8000")
            assert client._mode == "production"

    def test_env_var_case_insensitive(self):
        """An upper-case value in the variable is normalised."""
        overrides = {"OPENENV_CLIENT_MODE": "PRODUCTION"}
        with patch.dict(os.environ, overrides):
            client = GenericEnvClient(base_url="http://localhost:8000")
            assert client._mode == "production"

    def test_env_var_overrides_default(self):
        """The variable wins over the built-in default mode."""
        overrides = {"OPENENV_CLIENT_MODE": "production"}
        with patch.dict(os.environ, overrides):
            # Deliberately no ``mode`` argument here.
            client = GenericEnvClient(base_url="http://localhost:8000")
            assert client._mode == "production"

    def test_constructor_overrides_env_var(self):
        """An explicit constructor argument wins over the variable."""
        overrides = {"OPENENV_CLIENT_MODE": "production"}
        with patch.dict(os.environ, overrides):
            # The constructor argument is expected to take precedence.
            client = GenericEnvClient(
                base_url="http://localhost:8000", mode="simulation"
            )
            assert client._mode == "simulation"

    def test_invalid_env_var_raises_error(self):
        """A bad value in the variable raises a ValueError that names it."""
        overrides = {"OPENENV_CLIENT_MODE": "invalid"}
        with patch.dict(os.environ, overrides):
            with pytest.raises(ValueError) as exc_info:
                GenericEnvClient(base_url="http://localhost:8000")

        message = str(exc_info.value)
        assert "OPENENV_CLIENT_MODE" in message
        assert "invalid" in message.lower()
| # ============================================================================ | |
| # Mode Behavior Tests | |
| # ============================================================================ | |
class TestModeBehavior:
    """Check which wire messages a client sends in each mode.

    Both tests are coroutines, so they carry an explicit asyncio marker;
    without it pytest-asyncio in strict mode does not await them.
    """

    @pytest.mark.asyncio
    async def test_simulation_mode_uses_gym_protocol(self, clean_env, mock_websocket):
        """In simulation mode ``reset()`` sends a message with type='reset'."""
        client = GenericEnvClient(base_url="http://localhost:8000", mode="simulation")
        canned_reply = {
            "type": "response",
            "data": {"observation": {}, "reward": None, "done": False},
        }

        with patch.object(client, "_send") as mock_send:
            with patch.object(client, "_receive", return_value=canned_reply):
                with patch.object(client, "_ws", mock_websocket):
                    await client.reset()

        # The first positional argument of each ``_send`` call is the message.
        reset_calls = [
            call
            for call in mock_send.call_args_list
            if call.args[0].get("type") == "reset"
        ]
        assert len(reset_calls) > 0, (
            "Should send reset message with type='reset'"
        )

    @pytest.mark.asyncio
    async def test_production_mode_uses_jsonrpc_protocol(
        self, clean_env, mock_websocket
    ):
        """In production mode ``list_tools()`` sends a step with a list_tools action."""
        client = MCPToolClient(base_url="http://localhost:8000", mode="production")
        canned_reply = {
            "type": "response",
            "data": {
                "observation": {"tools": []},
                "reward": None,
                "done": False,
            },
        }

        with patch.object(client, "_send") as mock_send:
            with patch.object(client, "_receive", return_value=canned_reply):
                with patch.object(client, "_ws", mock_websocket):
                    await client.list_tools()

        step_calls = [
            call
            for call in mock_send.call_args_list
            if call.args[0].get("type") == "step"
        ]
        assert len(step_calls) > 0, "Should send message with type='step'"

        # The action payload of the first step message must be list_tools.
        step_message = step_calls[0].args[0]
        assert "data" in step_message
        assert step_message["data"].get("type") == "list_tools"
| # ============================================================================ | |
| # Mode Immutability Tests | |
| # ============================================================================ | |
class TestModeImmutability:
    """The mode chosen at construction must not be reassignable."""

    def test_mode_cannot_be_changed_after_creation(self, clean_env):
        """Assigning ``_mode`` on a fresh client must raise."""
        sim_client = GenericEnvClient(
            base_url="http://localhost:8000", mode="simulation"
        )

        # Either exception type counts as "assignment refused".
        rejected = (AttributeError, ValueError)
        with pytest.raises(rejected):
            sim_client._mode = "mcp"

    def test_mode_cannot_be_changed_after_connection(self, clean_env):
        """Assigning ``_mode`` while a socket object is attached must raise."""
        sim_client = GenericEnvClient(
            base_url="http://localhost:8000", mode="simulation"
        )
        rejected = (AttributeError, ValueError)

        with patch.object(sim_client, "_ws", MagicMock()):
            # Give the client a socket object so it looks connected.
            sim_client._ws = MagicMock()

            with pytest.raises(rejected):
                sim_client._mode = "mcp"
| # ============================================================================ | |
| # Cross-Client Mode Consistency Tests | |
| # ============================================================================ | |
class TestCrossClientModeConsistency:
    """Mode handling compared across GenericEnvClient and MCPToolClient."""

    def test_generic_client_supports_both_modes(self, clean_env):
        """GenericEnvClient accepts both 'simulation' and 'production'."""
        url = "http://localhost:8000"
        sim_client = GenericEnvClient(base_url=url, mode="simulation")
        prod_client = GenericEnvClient(base_url=url, mode="production")

        assert sim_client._mode == "simulation"
        assert prod_client._mode == "production"

    def test_mcp_client_defaults_to_production_mode(self, clean_env):
        """MCPToolClient with no mode argument ends up in 'production'."""
        tool_client = MCPToolClient(base_url="http://localhost:8000")
        assert tool_client._mode == "production"

    def test_mcp_client_cannot_use_simulation_mode(self, clean_env):
        """MCPToolClient refuses 'simulation' with an explanatory ValueError."""
        with pytest.raises(ValueError) as exc_info:
            MCPToolClient(base_url="http://localhost:8000", mode="simulation")

        message = str(exc_info.value)
        assert "MCPToolClient" in message
        assert "production" in message.lower()
| # ============================================================================ | |
| # Mode Documentation Tests | |
| # ============================================================================ | |
class TestModeDocumentation:
    """Check that ``GenericEnvClient.__init__`` documents the mode parameter."""

    def test_mode_parameter_in_docstring(self, clean_env):
        """The constructor docstring exists and mentions 'mode'."""
        docstring = GenericEnvClient.__init__.__doc__

        assert docstring is not None
        assert "mode" in docstring.lower()

    def test_mode_values_documented(self, clean_env):
        """The constructor docstring names both valid mode values."""
        docstring = GenericEnvClient.__init__.__doc__

        # Guard first so a missing docstring is reported as an assertion
        # failure rather than an AttributeError on ``None.lower()``.
        assert docstring is not None
        documented = docstring.lower()
        assert "simulation" in documented
        assert "production" in documented
| # ============================================================================ | |
| # Environment Code Mode Tests (from issue #347) | |
| # ============================================================================ | |
class _TestMCPEnv(MCPEnvironment):
    """Minimal concrete MCPEnvironment used by the code-mode tests.

    It keeps a single ``State`` object and returns a fixed, non-terminal
    observation from ``reset`` and ``_step_impl``.
    """

    def __init__(self, mcp_server):
        """Hand ``mcp_server`` to the base class and create the initial state."""
        super().__init__(mcp_server)
        self._state = State(episode_id="test", step_count=0)

    def reset(self, **kwargs):
        """Replace the state, using ``episode_id`` from kwargs or "test"."""
        episode = kwargs.get("episode_id", "test")
        self._state = State(episode_id=episode, step_count=0)
        return Observation(done=False, reward=0.0)

    def _step_impl(self, action, **kwargs):
        """Count one step; ``action`` itself is not inspected."""
        self._state.step_count += 1
        return Observation(done=False, reward=0.0)

    # NOTE(review): defined as a plain method here; if the base class exposes
    # ``state`` as a property this may have lost a decorator - confirm.
    def state(self):
        """Return the current ``State`` object."""
        return self._state
| # ============================================================================= | |
| # Test Fixtures - Environment Code Mode | |
| # ============================================================================= | |
@pytest.fixture
def mcp_server_with_tools():
    """Provide a FastMCP server named "test-code-mode" with two tools.

    The tools are ``add(a, b)`` and ``multiply(x, y)``. Each is registered on
    the server through the ``tool`` decorator; a bare nested ``def`` would be
    discarded when this function returns and the server would expose nothing.

    Returns:
        The configured ``FastMCP`` instance.
    """
    mcp = FastMCP("test-code-mode")

    @mcp.tool()
    def add(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    @mcp.tool()
    def multiply(x: int, y: int) -> int:
        """Multiply two numbers."""
        return x * y

    return mcp
| # ============================================================================= | |
| # Code Mode Capability Tests | |
| # ============================================================================= | |
class TestCodeModeCapability:
    """Checks that an environment advertises code-mode support."""

    def test_environment_has_code_mode_capability(self, mcp_server_with_tools):
        """The environment has a ``supports_code_mode`` attribute set to True."""
        environment = _TestMCPEnv(mcp_server_with_tools)

        assert hasattr(environment, "supports_code_mode")
        assert environment.supports_code_mode is True
| # ============================================================================= | |
| # Code Mode Tests (with FastMCP Server) | |
| # ============================================================================= | |
class TestCodeModeWithFastMCP:
    """Code-mode tests that run against the FastMCP fixture server."""

    def test_get_callables_returns_tool_functions(self, mcp_server_with_tools):
        """``get_callables()`` exposes both server tools as callables."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        tools = environment.get_callables()

        assert "add" in tools
        assert callable(tools["add"])
        assert "multiply" in tools
        assert callable(tools["multiply"])

    def test_callables_work_directly(self, mcp_server_with_tools):
        """A callable taken from ``get_callables()`` can be invoked as-is."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        add_tool = environment.get_callables()["add"]

        assert add_tool(a=5, b=3) == 8

    def test_code_mode_executes_python_directly(self, mcp_server_with_tools):
        """``execute_code()`` runs a snippet that calls a tool by name."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = "\nresult = add(a=5, b=3)\n"
        observation = environment.execute_code(snippet)

        assert isinstance(observation, Observation)
        assert observation.metadata.get("result") == 8

    def test_code_mode_multiple_tool_calls_in_one_step(self, mcp_server_with_tools):
        """One snippet may chain several tool calls."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = (
            "\n"
            "x = add(a=2, b=3)\n"
            "y = multiply(x=x, y=4)\n"
            "result = y\n"
        )
        observation = environment.execute_code(snippet)

        # (2 + 3) * 4 = 20
        assert observation.metadata.get("result") == 20

    def test_code_mode_with_complex_python_logic(self, mcp_server_with_tools):
        """Ordinary Python control flow can wrap tool calls."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = (
            "\n"
            "numbers = [1, 2, 3, 4, 5]\n"
            "total = 0\n"
            "for n in numbers:\n"
            "    total = add(a=total, b=n)\n"
            "result = total\n"
        )
        observation = environment.execute_code(snippet)

        assert observation.metadata.get("result") == 15  # Sum of 1+2+3+4+5
| # ============================================================================= | |
| # Code Mode with Mode-Aware Tools | |
| # ============================================================================= | |
class TestCodeModeWithModeAwareTools:
    """Code-mode behaviour when the available tools depend on the mode.

    NOTE(review): in each test the helper functions defined inside
    ``ModeEnv.__init__`` are never attached to the server or the environment
    in the code visible here. They look like they were meant to be registered
    as mode-specific tools by a decorator that is not shown - confirm against
    the environment's tool-registration API before relying on these tests.
    """

    def test_get_callables_includes_mode_specific_tools(self):
        """In simulation mode only the simulation tool should be exposed."""
        server = FastMCP("mode-test")

        class ModeEnv(_TestMCPEnv):
            def __init__(self):
                super().__init__(server)
                self._mode = "simulation"

                def sim_tool(x: int) -> int:
                    return x * 10

                def prod_tool(x: int) -> int:
                    return x * 100

        environment = ModeEnv()
        tools = environment.get_callables()

        # Simulation mode: sim_tool present, prod_tool absent.
        assert "sim_tool" in tools
        assert "prod_tool" not in tools
        assert tools["sim_tool"](x=5) == 50

    def test_get_callables_switches_with_mode(self):
        """Changing ``_mode`` changes which ``lookup`` implementation is returned."""
        server = FastMCP("mode-switch-test")

        class ModeEnv(_TestMCPEnv):
            def __init__(self):
                super().__init__(server)
                self._mode = "simulation"

                def lookup(query: str) -> str:
                    return f"sim:{query}"

                def lookup(query: str) -> str:  # noqa: F811
                    return f"prod:{query}"

        environment = ModeEnv()

        # Simulation mode first.
        sim_tools = environment.get_callables()
        assert sim_tools["lookup"](query="test") == "sim:test"

        # Then flip to production and ask again.
        environment._mode = "production"
        prod_tools = environment.get_callables()
        assert prod_tools["lookup"](query="test") == "prod:test"

    def test_execute_code_uses_mode_specific_tools(self):
        """``execute_code()`` resolves ``compute`` according to the current mode."""
        server = FastMCP("code-mode-test")

        class ModeEnv(_TestMCPEnv):
            def __init__(self):
                super().__init__(server)
                self._mode = "simulation"

                def compute(x: int) -> int:
                    return x + 1

                def compute(x: int) -> int:  # noqa: F811
                    return x + 1000

        environment = ModeEnv()
        snippet = "result = compute(x=5)"

        # Simulation mode first.
        sim_observation = environment.execute_code(snippet)
        assert sim_observation.metadata.get("result") == 6

        # Then flip to production and run the same snippet.
        environment._mode = "production"
        prod_observation = environment.execute_code(snippet)
        assert prod_observation.metadata.get("result") == 1005
| # ============================================================================= | |
| # Tool-Calling Mode Tests (Backwards Compatibility) | |
| # ============================================================================= | |
class TestToolCallingMode:
    """Backwards-compatibility checks for the one-tool-call-per-step path."""

    def test_list_tools_still_works(self, mcp_server_with_tools):
        """Stepping with a ListToolsAction yields a non-empty tool listing."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        listing = environment.step(ListToolsAction())

        assert isinstance(listing, ListToolsObservation)
        assert len(listing.tools) > 0

    def test_code_mode_preserves_tool_schemas_for_discovery(
        self, mcp_server_with_tools
    ):
        """Discovery via ``step()`` and via ``get_callables()`` agree on the tool count."""
        environment = _TestMCPEnv(mcp_server_with_tools)

        # Discovery through the step() path.
        listing = environment.step(ListToolsAction())
        assert isinstance(listing, ListToolsObservation)
        assert len(listing.tools) > 0

        # Discovery through the code-mode path.
        tools = environment.get_callables()
        assert len(tools) == len(listing.tools)
| # ============================================================================= | |
| # Error Handling Tests | |
| # ============================================================================= | |
class TestCodeModeErrorHandling:
    """Failures inside ``execute_code()`` must surface in observation metadata."""

    def test_code_mode_handles_syntax_errors(self, mcp_server_with_tools):
        """A snippet that does not parse is reported as a syntax error."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = "\nresult = add(a=5, b= # Syntax error\n"
        observation = environment.execute_code(snippet)

        assert observation.metadata.get("error") is not None
        assert "syntax" in observation.metadata["error"].lower()

    def test_code_mode_handles_runtime_errors(self, mcp_server_with_tools):
        """A snippet that fails while running is reported as an error."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = '\nresult = add(a=5, b="not a number") # Type error\n'
        observation = environment.execute_code(snippet)

        assert observation.metadata.get("error") is not None

    def test_code_mode_handles_missing_tool(self, mcp_server_with_tools):
        """Calling an unknown tool name is reported with that name in the error."""
        environment = _TestMCPEnv(mcp_server_with_tools)
        environment.reset()

        snippet = "\nresult = nonexistent_tool(x=1)\n"
        observation = environment.execute_code(snippet)

        assert observation.metadata.get("error") is not None
        assert "nonexistent_tool" in observation.metadata["error"]
| # ============================================================================= | |
| # Integration Tests | |
| # ============================================================================= | |
class TestCodeModeIntegration:
    """End-to-end code-mode check against the EchoEnvironment."""

    def test_echo_env_in_code_mode(self):
        """A snippet calling ``echo_message`` returns the echoed text as the result."""
        # Imported inside the test so the module loads without echo_env.
        from echo_env.server.echo_environment import EchoEnvironment

        environment = EchoEnvironment()
        environment.reset()

        snippet = (
            "\n"
            'msg = echo_message(message="Hello from code mode!")\n'
            "result = msg\n"
        )
        observation = environment.execute_code(snippet)

        echoed = str(observation.metadata.get("result"))
        assert "Hello from code mode!" in echoed