"""WebSocket smoke test: spin up the FastAPI server in-process and drive it
through :class:`physix.PhysiXEnv` over a real WebSocket connection.
This catches regressions in the wire protocol (action/observation
serialisation, session lifecycle) that the in-process
``test_environment.py`` cannot.
"""
from __future__ import annotations
import asyncio
import contextlib
import socket
import threading
import time
from collections.abc import Iterator
import pytest
import uvicorn
from physix.client import PhysiXEnv
from physix.models import PhysiXAction
from physix.server.app import app
# ---------------------------------------------------------------------------
# Server fixture
# ---------------------------------------------------------------------------
def _free_port() -> int:
"""Return an OS-assigned free TCP port."""
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("127.0.0.1", 0))
return int(sock.getsockname()[1])
@pytest.fixture(scope="module")
def server_url() -> Iterator[str]:
    """Run uvicorn in a daemon thread for the duration of the module.

    Yields the base HTTP URL of the running server
    (``http://127.0.0.1:<port>``). On teardown the server is asked to exit
    and its thread is joined. This also happens when startup fails, so a
    half-started server is always signalled to stop.

    Fails the requesting tests via ``pytest.fail`` if uvicorn does not report
    ``started`` within the timeout or its thread exits before starting.
    """
    port = _free_port()
    config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    try:
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + 10.0
        while not server.started and time.monotonic() < deadline:
            if not thread.is_alive():
                # server.run() returned without starting (e.g. the port was
                # taken between _free_port() and bind); stop waiting early.
                break
            time.sleep(0.05)
        if not server.started:
            pytest.fail("uvicorn server failed to start within timeout")
        yield f"http://127.0.0.1:{port}"
    finally:
        # Runs on normal teardown and when pytest.fail above raises.
        server.should_exit = True
        thread.join(timeout=5.0)
# ---------------------------------------------------------------------------
# Test
# ---------------------------------------------------------------------------
async def _drive_episode(base_url: str) -> None:
    """Run one reset + step episode against the live server at *base_url*.

    This helper does not match pytest's default ``test_*`` collection pattern.
    It is run by ``test_websocket_round_trip`` through ``asyncio.run``, so it
    deliberately carries no ``pytest.mark.asyncio`` marker. That marker would
    be inert here, and under ``--strict-markers`` it would require the asyncio
    plugin just to import the module.

    Checks that ``reset`` returns a non-terminal ``free_fall`` observation at
    turn 0 with a non-empty trajectory. Then checks that submitting the
    free-fall equation yields ``format == 1.0``, ``match >= 0.9`` and a
    finished episode.
    """
    async with PhysiXEnv(base_url=base_url) as env:
        result = await env.reset(system_id="free_fall", seed=11)
        observation = result.observation
        assert result.done is False
        assert observation.system_id == "free_fall"
        assert observation.turn == 0
        assert len(observation.trajectory) > 0

        result = await env.step(
            PhysiXAction(equation="d2y/dt2 = -9.81", params={}, rationale="free fall")
        )
        breakdown = result.observation.reward_breakdown
        # Attach the whole breakdown to failures so a regression is diagnosable.
        assert breakdown["format"] == 1.0, breakdown
        assert breakdown["match"] >= 0.9, breakdown
        assert result.done is True
def test_websocket_round_trip(server_url: str) -> None:
    """Drive a full reset/step episode over a real WebSocket connection.

    The episode coroutine is executed synchronously on a fresh event loop
    created by ``asyncio.run``.
    """
    episode = _drive_episode(server_url)
    asyncio.run(episode)