# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Unit tests for Unity ML-Agents environment server.

=============================================================================
HOW TO RUN THESE TESTS
=============================================================================

Running Tests:
    # From the OpenEnv repository root directory:

    # Run all Unity environment tests
    pytest tests/envs/test_unity_environment.py -v

    # Run with longer timeout (recommended for first run - downloads ~500MB binaries)
    pytest tests/envs/test_unity_environment.py -v --timeout=300

    # Run with print output visible
    pytest tests/envs/test_unity_environment.py -v -s

=============================================================================
"""

import os
import subprocess
import sys
import time

import pytest
import requests

# Add the project root to the path for envs imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

# Check if mlagents-envs is installed
try:
    import mlagents_envs  # noqa: F401

    MLAGENTS_INSTALLED = True
except ImportError:
    MLAGENTS_INSTALLED = False

# Skip all tests if mlagents-envs is not installed
pytestmark = pytest.mark.skipif(
    not MLAGENTS_INSTALLED, reason="mlagents-envs not installed"
)

# These imports rely on the sys.path tweak above, so they cannot move to the top.
from envs.unity_env.client import UnityEnv  # noqa: E402
from envs.unity_env.models import UnityAction, UnityObservation, UnityState  # noqa: E402


@pytest.fixture(scope="module")
def server():
    """Start the Unity environment server as a background process.

    Launches ``uvicorn envs.unity_env.server.app:app`` on port 8011 and sends
    its output to ``tests/unity_server_test.log``. It then polls ``/health``
    until the server answers 200 and yields the base URL
    (``http://localhost:8011``).

    If the server process exits early or never becomes healthy, the captured
    log is printed and every test using this fixture is skipped. On teardown
    the server is terminated, or killed if it has not exited within 10 seconds.

    Note: Unity environments can take 30-120 seconds to initialize on first run
    due to binary downloads (~500MB). Subsequent runs use cached binaries.
    """
    # Define paths for subprocess environment
    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    SRC_PATH = os.path.join(ROOT_DIR, "src")
    PORT = 8011  # Use a unique port to avoid conflicts
    localhost = f"http://localhost:{PORT}"
    max_attempts = 12  # Health-check polls, 5 seconds apart

    print(f"\n--- Starting Unity ML-Agents server on port {PORT} ---")

    server_env = {
        **os.environ,
        # os.pathsep is ":" on POSIX and ";" on Windows. A hard-coded ":"
        # yields an unusable PYTHONPATH on Windows.
        "PYTHONPATH": os.pathsep.join([SRC_PATH, ROOT_DIR]),
        "UNITY_NO_GRAPHICS": "1",  # Run headless for testing
        "UNITY_TIME_SCALE": "20",  # Speed up for faster tests
        # Bypass proxy for localhost
        "NO_PROXY": "localhost,127.0.0.1",
        "no_proxy": "localhost,127.0.0.1",
    }

    # Use uvicorn directly instead of gunicorn for simpler setup.
    # The tests only talk to localhost, so bind to the loopback interface
    # rather than exposing the server on every network interface.
    uvicorn_command = [
        sys.executable,
        "-m",
        "uvicorn",
        "envs.unity_env.server.app:app",
        "--host",
        "127.0.0.1",
        "--port",
        str(PORT),
    ]

    # Create a log file for server output (the child process writes to it).
    log_file = os.path.join(ROOT_DIR, "tests", "unity_server_test.log")
    log_handle = open(log_file, "w", encoding="utf-8")
    try:
        server_process = subprocess.Popen(
            uvicorn_command,
            env=server_env,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            text=True,
            cwd=ROOT_DIR,
        )
    except Exception:
        # Don't leak the log handle if the server could not even be spawned.
        log_handle.close()
        raise

    # Wait for server to become healthy
    # Note: Initial startup is quick, but first reset() will download binaries
    print("\n--- Waiting for server to become healthy... ---")
    time.sleep(2)  # Give server time to fully initialize

    # Bypass proxy for localhost requests
    no_proxy = {"http": None, "https": None}

    is_healthy = False
    for attempt in range(1, max_attempts + 1):
        # Fail fast: if uvicorn has already exited (import error, port in
        # use, ...), further polling cannot make it healthy.
        if server_process.poll() is not None:
            print(
                f"Server process exited early with code {server_process.returncode}."
            )
            break
        try:
            response = requests.get(f"{localhost}/health", timeout=5, proxies=no_proxy)
            if response.status_code == 200:
                is_healthy = True
                print("✅ Server is running and healthy!")
                break
            print(
                f"Attempt {attempt}/{max_attempts}: /health returned "
                f"{response.status_code}, waiting 5 seconds..."
            )
        except requests.exceptions.RequestException as e:
            print(
                f"Attempt {attempt}/{max_attempts}: Server not ready ({e}), "
                "waiting 5 seconds..."
            )
        if attempt < max_attempts:  # No point sleeping after the final attempt
            time.sleep(5)

    if not is_healthy:
        print("❌ Server did not become healthy in time. Aborting.")
        print("\n--- Server Logs ---")
        server_process.kill()
        server_process.wait()  # Reap the child so it does not linger as a zombie
        log_handle.close()
        # The server may emit bytes that are not valid in the platform default
        # encoding, so decode leniently when dumping the log.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            print(f.read())
        pytest.skip("Server failed to start")

    yield localhost

    # Cleanup
    print("\n--- Cleaning up server ---")
    try:
        server_process.terminate()
        server_process.wait(timeout=10)
        print("✅ Server process terminated")
    except subprocess.TimeoutExpired:
        server_process.kill()
        server_process.wait()  # Reap the killed child
        print("✅ Server process killed")
    except ProcessLookupError:
        print("✅ Server process was already terminated")
    finally:
        log_handle.close()


class TestHealthEndpoint:
    """Tests for the health endpoint."""

    def test_health_endpoint_returns_200(self, server):
        """Test that the health endpoint returns 200 OK."""
        # Explicit timeout: without one, a hung server would block the suite forever.
        response = requests.get(
            f"{server}/health", timeout=10, proxies={"http": None, "https": None}
        )
        assert response.status_code == 200

    def test_health_endpoint_returns_status(self, server):
        """Test that the health endpoint returns status field."""
        response = requests.get(
            f"{server}/health", timeout=10, proxies={"http": None, "https": None}
        )
        data = response.json()
        assert "status" in data
        assert data["status"] == "healthy"


class TestUnityEnvClient:
    """Tests for the UnityEnv client."""

    # Note: This test may take up to 3 minutes on first run (binary download)
    def test_reset_returns_valid_observation(self, server):
        """Test that reset() returns a valid observation."""
        with UnityEnv(base_url=server) as env:
            result = env.reset(env_id="PushBlock")

            assert result is not None
            assert result.observation is not None
            assert isinstance(result.observation, UnityObservation)
            assert hasattr(result.observation, "vector_observations")
            assert hasattr(result.observation, "behavior_name")
            assert hasattr(result.observation, "action_spec_info")
            assert result.observation.done is False

    def test_reset_with_different_environments(self, server):
        """Test that reset() can switch between environments."""
        with UnityEnv(base_url=server) as env:
            # Reset to PushBlock
            result1 = env.reset(env_id="PushBlock")
            assert result1.observation.behavior_name is not None
            assert "Push" in result1.observation.behavior_name

            # Reset to 3DBall
            result2 = env.reset(env_id="3DBall")
            assert result2.observation.behavior_name is not None
            assert "3DBall" in result2.observation.behavior_name

    def test_step_discrete_action(self, server):
        """Test that step() works with discrete actions (PushBlock)."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")

            # PushBlock has 7 discrete actions (0-6)
            action = UnityAction(discrete_actions=[1])  # Move forward
            result = env.step(action)

            assert result is not None
            assert result.observation is not None
            assert isinstance(result.reward, (int, float)) or result.reward is None
            assert isinstance(result.done, bool)

    def test_step_continuous_action(self, server):
        """Test that step() works with continuous actions (3DBall)."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="3DBall")

            # 3DBall has 2 continuous actions
            action = UnityAction(continuous_actions=[0.5, -0.3])
            result = env.step(action)

            assert result is not None
            assert result.observation is not None
            assert isinstance(result.reward, (int, float)) or result.reward is None
            assert isinstance(result.done, bool)

    def test_step_multiple_times(self, server):
        """Test that step() can be called multiple times."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")

            for i in range(10):
                action = UnityAction(discrete_actions=[i % 7])
                result = env.step(action)
                assert result.observation is not None

    def test_state_endpoint(self, server):
        """Test that state() returns valid state information."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")
            state = env.state()

            assert state is not None
            assert isinstance(state, UnityState)
            assert hasattr(state, "env_id")
            assert hasattr(state, "episode_id")
            assert hasattr(state, "step_count")
            assert hasattr(state, "behavior_name")
            assert hasattr(state, "action_spec")
            assert state.env_id == "PushBlock"

    def test_step_count_increments(self, server):
        """Test that step count increments correctly."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")
            state1 = env.state()
            assert state1.step_count == 0

            action = UnityAction(discrete_actions=[1])
            env.step(action)
            state2 = env.state()
            assert state2.step_count == 1

            env.step(action)
            state3 = env.state()
            assert state3.step_count == 2

    def test_reset_resets_step_count(self, server):
        """Test that reset() resets the step count."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")

            # Take some steps
            action = UnityAction(discrete_actions=[1])
            for _ in range(5):
                env.step(action)

            state1 = env.state()
            assert state1.step_count == 5

            # Reset
            env.reset(env_id="PushBlock")
            state2 = env.state()
            assert state2.step_count == 0

    def test_episode_id_changes_on_reset(self, server):
        """Test that episode ID changes on each reset."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")
            state1 = env.state()

            env.reset(env_id="PushBlock")
            state2 = env.state()

            assert state1.episode_id != state2.episode_id

    def test_action_spec_info(self, server):
        """Test that action spec info is provided correctly."""
        with UnityEnv(base_url=server) as env:
            # PushBlock - discrete actions
            result = env.reset(env_id="PushBlock")
            action_spec = result.observation.action_spec_info

            assert action_spec is not None
            assert action_spec.get("is_discrete") is True
            assert action_spec.get("discrete_size") == 1
            assert len(action_spec.get("discrete_branches", [])) > 0

            # 3DBall - continuous actions
            result = env.reset(env_id="3DBall")
            action_spec = result.observation.action_spec_info

            assert action_spec is not None
            assert action_spec.get("is_continuous") is True
            assert action_spec.get("continuous_size") == 2


class TestUnityEnvModels:
    """Tests for Unity environment models."""

    def test_unity_action_discrete(self):
        """Test creating a discrete UnityAction."""
        action = UnityAction(discrete_actions=[1, 2, 3])
        assert action.discrete_actions == [1, 2, 3]
        assert action.continuous_actions is None

    def test_unity_action_continuous(self):
        """Test creating a continuous UnityAction."""
        action = UnityAction(continuous_actions=[0.5, -0.3, 1.0])
        assert action.continuous_actions == [0.5, -0.3, 1.0]
        assert action.discrete_actions is None

    def test_unity_action_with_metadata(self):
        """Test creating a UnityAction with metadata."""
        action = UnityAction(
            discrete_actions=[1], metadata={"test": "value", "number": 42}
        )
        assert action.discrete_actions == [1]
        assert action.metadata == {"test": "value", "number": 42}

    def test_unity_observation_creation(self):
        """Test creating a UnityObservation."""
        obs = UnityObservation(
            vector_observations=[1.0, 2.0, 3.0],
            behavior_name="TestBehavior",
            done=False,
            reward=0.5,
            action_spec_info={"is_discrete": True},
            observation_spec_info={"count": 1},
        )
        assert obs.vector_observations == [1.0, 2.0, 3.0]
        assert obs.behavior_name == "TestBehavior"
        assert obs.done is False
        assert obs.reward == 0.5

    def test_unity_state_creation(self):
        """Test creating a UnityState."""
        state = UnityState(
            episode_id="test-episode-123",
            step_count=10,
            env_id="PushBlock",
            behavior_name="PushBlockBehavior",
            action_spec={"is_discrete": True},
            observation_spec={"count": 1},
            available_envs=["PushBlock", "3DBall"],
        )
        assert state.episode_id == "test-episode-123"
        assert state.step_count == 10
        assert state.env_id == "PushBlock"
        assert state.available_envs == ["PushBlock", "3DBall"]


class TestAvailableEnvironments:
    """Tests for available environments functionality."""

    def test_available_environments_static_method(self):
        """Test the static available_environments method."""
        envs = UnityEnv.available_environments()
        assert isinstance(envs, list)
        assert "PushBlock" in envs
        assert "3DBall" in envs

    def test_available_envs_from_state(self, server):
        """Test getting available environments from state."""
        with UnityEnv(base_url=server) as env:
            env.reset(env_id="PushBlock")
            state = env.state()

            assert state.available_envs is not None
            assert isinstance(state.available_envs, list)
            assert len(state.available_envs) > 0
            assert "PushBlock" in state.available_envs
            assert "3DBall" in state.available_envs