Spaces:
Sleeping
Sleeping
| """WebSocket smoke test: spin up the FastAPI server in-process and drive it | |
| through :class:`physix.PhysiXEnv` over a real WebSocket connection. | |
| This catches regressions in the wire protocol (action/observation | |
| serialisation, session lifecycle) that the in-process | |
| ``test_environment.py`` cannot. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import contextlib | |
| import socket | |
| import threading | |
| import time | |
| from collections.abc import Iterator | |
| import pytest | |
| import uvicorn | |
| from physix.client import PhysiXEnv | |
| from physix.models import PhysiXAction | |
| from physix.server.app import app | |
| # --------------------------------------------------------------------------- | |
| # Server fixture | |
| # --------------------------------------------------------------------------- | |
| def _free_port() -> int: | |
| """Return an OS-assigned free TCP port.""" | |
| with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: | |
| sock.bind(("127.0.0.1", 0)) | |
| return int(sock.getsockname()[1]) | |
@pytest.fixture(scope="module")
def server_url() -> Iterator[str]:
    """Run uvicorn in a daemon thread for the duration of the module.

    Starts the FastAPI ``app`` on an OS-assigned loopback port, waits up to
    ten seconds for uvicorn to report that it has started, and yields the
    server's base URL. On teardown (or if startup fails) the server is asked
    to exit and the thread is joined with a five second timeout.

    Yields:
        The ``http://127.0.0.1:<port>`` base URL of the running server.
    """
    port = _free_port()
    config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()

    # Everything after ``thread.start()`` lives inside the ``try`` so the
    # server is always asked to stop, including when startup times out.
    try:
        # Monotonic clock: the deadline must not move if the wall clock is
        # adjusted while we are waiting.
        deadline = time.monotonic() + 10.0
        while not server.started and time.monotonic() < deadline:
            if not thread.is_alive():
                # The server thread has already ended without reporting
                # startup; further polling cannot succeed.
                break
            time.sleep(0.05)

        if not server.started:
            pytest.fail("uvicorn server failed to start within timeout")

        yield f"http://127.0.0.1:{port}"
    finally:
        server.should_exit = True
        thread.join(timeout=5.0)
| # --------------------------------------------------------------------------- | |
| # Test | |
| # --------------------------------------------------------------------------- | |
async def _drive_episode(base_url: str) -> None:
    """Run a single reset/step exchange against the server at *base_url*.

    Resets the ``free_fall`` system with a fixed seed, checks the initial
    observation, then submits one equation and checks the reward breakdown
    and that the episode is reported as done.
    """
    async with PhysiXEnv(base_url=base_url) as env:
        # --- reset -------------------------------------------------------
        reset_result = await env.reset(system_id="free_fall", seed=11)
        assert reset_result.done is False
        initial_obs = reset_result.observation
        assert initial_obs.system_id == "free_fall"
        assert initial_obs.turn == 0
        assert len(initial_obs.trajectory) > 0

        # --- step --------------------------------------------------------
        action = PhysiXAction(
            equation="d2y/dt2 = -9.81",
            params={},
            rationale="free fall",
        )
        step_result = await env.step(action)

        scores = step_result.observation.reward_breakdown
        assert scores["format"] == 1.0
        assert scores["match"] >= 0.9
        assert step_result.done is True
def test_websocket_round_trip(server_url: str) -> None:
    """Drive one full episode over the wire against the live server."""
    episode = _drive_episode(server_url)
    asyncio.run(episode)