# Parlay / smoke_test.py
# Last commit: f33ad7b by sh4shv4t
#   "fix: resolve WebSocket HTTP 403 on OpenEnv env server"
import asyncio
import json
import os
import sys

import websockets
# WebSocket endpoint of the env server. Its router is mounted under "/env",
# so the socket lives at "/env/ws"; connecting to a bare "/ws" is rejected.
# Set PARLAY_WS_URL to point the smoke test at another host/port/deployment;
# an unset or empty variable falls back to the local default.
WS_URL = os.environ.get("PARLAY_WS_URL") or "ws://localhost:8001/env/ws"
def _observation(data: dict) -> dict:
    """Extract the observation payload from a reset/step reply.

    A reply whose ``"observation"`` entry is a dict is unwrapped to that dict.
    Failing that, a reply carrying an ``"error"`` key becomes an
    ``AssertionError``; anything else is handed back untouched.
    """
    inner = data["observation"] if "observation" in data else None
    if isinstance(inner, dict):
        return inner
    if "error" not in data:
        return data
    raise AssertionError(f"server error: {data.get('error')}")
def _check(condition: object, message: str) -> None:
    """Raise ``AssertionError(message)`` if *condition* is falsy.

    Used instead of ``assert`` so the smoke checks still fire when Python runs
    with ``-O`` (which strips ``assert`` statements entirely).
    """
    if not condition:
        raise AssertionError(message)


async def _rpc(ws, method: str, params: dict, timeout: float):
    """Send one ``{"method": ..., "params": ...}`` JSON frame on *ws* and return the decoded reply.

    Raises ``asyncio.TimeoutError`` if no frame arrives within *timeout* seconds.
    """
    await ws.send(json.dumps({"method": method, "params": params}))
    raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
    return json.loads(raw)


async def smoke_test() -> None:
    """Run the end-to-end WebSocket smoke test against ``WS_URL``.

    Checks, in order: reset() returns a session id and a ``step_count``;
    step() returns a numeric ``reward``; state() returns a non-empty,
    non-error payload; ZOPA-erosion and ``last_utterance`` fields (WARN only);
    a second step() increments ``step_count``. Prints a PASS/WARN line per
    check and exits the process with status 1 on the first hard failure.
    """
    print(f"Connecting to {WS_URL}...")
    # Flipped once the handshake succeeds, so timeouts and OS errors can be
    # reported as "could not connect" vs. "connected, then failed".
    connected = False
    try:
        async with websockets.connect(
            WS_URL,
            open_timeout=8,
            additional_headers={"Origin": "http://localhost:8000"},
        ) as ws:
            connected = True

            # --- Test 1: reset() ---
            print("Testing reset()...")
            data0 = await _rpc(
                ws,
                "reset",
                {"persona": "shark", "scenario_id": "saas_enterprise"},
                timeout=10,
            )
            session_id = data0.get("session_id")
            if not session_id and "error" in data0:
                raise AssertionError(f"reset failed: {data0!r}")
            _check(session_id, f"reset() missing session_id. Got: {data0!r}")
            obs = _observation(data0)
            print(f" reset() response keys: {sorted(data0.keys())}")
            # API uses step_count (not turn_count) on the observation
            _check("step_count" in obs, f"reset() missing step_count. Got: {list(obs.keys())}")
            print(
                f" PASS reset() | step_count={obs.get('step_count')} "
                f"tension_score={obs.get('tension_score')} "
                f"cumulative_reward={obs.get('cumulative_reward', 'N/A')}"
            )

            # --- Test 2: step() ---
            print("Testing step()...")
            # Use an offer above the scenario ceiling so the episode does not
            # immediately satisfy deal_reached (else step 2 would get "already done").
            data1 = await _rpc(
                ws,
                "step",
                {
                    "session_id": session_id,
                    "utterance": "We are prepared to offer 500000 for the annual contract (probe).",
                    "offer_amount": 500_000.0,
                    "tactical_move": None,
                },
                timeout=20,
            )
            if "error" in data1 and "observation" not in data1:
                raise AssertionError(f"step failed: {data1!r}")
            obs2 = _observation(data1)
            print(f" step() envelope keys: {sorted(data1.keys())}")
            print(f" step() observation keys: {sorted(obs2.keys())}")
            _check("reward" in obs2, f"step() missing reward. Got: {list(obs2.keys())}")
            print(
                f" PASS step() | reward={obs2.get('reward')} "
                f"step_count={obs2.get('step_count')} "
                f"tension_score={obs2.get('tension_score')}"
            )

            # --- Test 3: state() ---
            print("Testing state()...")
            st = await _rpc(ws, "state", {"session_id": session_id}, timeout=8)
            # An error envelope is a non-empty dict too; reject it explicitly
            # instead of counting it as a valid state payload.
            if isinstance(st, dict) and st.get("error") and "state" not in st:
                raise AssertionError(f"state failed: {st!r}")
            state = st.get("state", st) if isinstance(st, dict) else st
            print(f" state() response keys: {sorted(st.keys()) if isinstance(st, dict) else st}")
            _check(isinstance(state, dict) and len(state) > 0, "state() returned empty")
            print(f" PASS state() | {len(state)} keys in state payload")

            # --- Test 4: ZOPA erosion fields in observation ---
            print("Testing ZOPA erosion fields...")
            has_erosion = "zopa_erosion_ticks" in obs2 or "zopa_width_pct_remaining" in obs2
            if has_erosion:
                width = obs2.get("zopa_width_pct_remaining", 1.0)
                # Only numbers take the float format; a null/str value would
                # otherwise raise and abort the run on a WARN-level check.
                width_txt = f"{width:.2f}" if isinstance(width, (int, float)) else repr(width)
                print(
                    f" PASS ZOPA erosion fields present | "
                    f"ticks={obs2.get('zopa_erosion_ticks', 0)} "
                    f"width_pct={width_txt}"
                )
            else:
                print(f" WARN ZOPA erosion fields NOT in obs2 keys: {sorted(obs2.keys())}")
                print(
                    " This means the ZOPA collapse mechanic is not "
                    "surfaced in the observation model."
                )

            # --- Test 5: reward variance (not always 0.0) ---
            print("Testing reward is non-trivial...")
            reward_val = obs2.get("reward")
            # bool is an int subclass; a JSON true/false is not a real reward.
            _check(
                isinstance(reward_val, (int, float)) and not isinstance(reward_val, bool),
                f"reward is not numeric: {type(reward_val)}",
            )
            print(f" PASS reward is numeric: {reward_val}")
            if reward_val == 0.0:
                print(
                    " WARN reward is exactly 0.0 — step rewards may not "
                    "be accumulating. Check grader.py compute_step_reward."
                )

            # --- Test 6: last utterance in observation (player line echoed in this env) ---
            print("Testing utterance in observation...")
            last_u = obs2.get("last_utterance", "")
            if last_u:
                preview = (last_u[:80] + "...") if len(last_u) > 80 else last_u
                print(f" PASS last_utterance set: '{preview}'")
            else:
                print(" WARN last_utterance empty in obs2")

            # --- Test 7: second step to verify turn increments ---
            print("Testing turn increment on second step...")
            data2 = await _rpc(
                ws,
                "step",
                {
                    "session_id": session_id,
                    "utterance": "I understand your position. We could go to 501000 (still probing).",
                    "offer_amount": 501_000.0,
                    "tactical_move": None,
                },
                timeout=20,
            )
            # Same envelope check as the first step, for a clearer failure message.
            if "error" in data2 and "observation" not in data2:
                raise AssertionError(f"step failed: {data2!r}")
            obs3 = _observation(data2)
            turn2 = obs3.get("step_count", -1)
            turn1 = obs2.get("step_count", -1)
            # Type-check first so a null/str step_count fails the check
            # instead of raising TypeError from the comparison.
            _check(
                isinstance(turn1, (int, float))
                and isinstance(turn2, (int, float))
                and turn2 > turn1,
                f"Step count did not increment: was {turn1}, now {turn2}",
            )
            print(f" PASS Step incremented: {turn1} -> {turn2}")

            print()
            print("=" * 55)
            print(" ALL SMOKE TESTS PASSED")
            print(" The WebSocket env is working correctly.")
            print("=" * 55)
            print()
            print("QUICK STATS:")
            print(f" WebSocket URL: {WS_URL}")
            print(f" Reward at step 1: {obs2.get('reward')}")
            print(f" Reward at step 2: {obs3.get('reward')}")
            print(f" Tension (step 2): {obs3.get('tension_score')}")
            print(f" Step count: {turn2}")
    except ConnectionRefusedError:
        print()
        print("FAIL ConnectionRefused — server is not running on this port.")
        print(f" Tried: {WS_URL}")
        print(" Start the env server first, then rerun this test.")
        sys.exit(1)
    except websockets.exceptions.InvalidStatus as e:
        print()
        print(f"FAIL Server rejected connection: {e}")
        response = getattr(e, "response", None)
        status = getattr(response, "status_code", "unknown")
        print(f" HTTP status: {status}")
        if status == 403:
            print(" This is the 403 bug — wrong path or host policy.")
        print(" Confirm WS_URL is ws://localhost:PORT/env/ws (not /ws).")
        sys.exit(1)
    except AssertionError as e:
        print()
        print(f"FAIL Assertion failed: {e}")
        sys.exit(1)
    except (asyncio.TimeoutError, TimeoutError):
        # Must precede the OSError branch: TimeoutError subclasses OSError.
        print()
        if not connected:
            print("FAIL Timeout — WebSocket handshake did not complete in time.")
            print(f" Tried: {WS_URL}")
        else:
            print("FAIL Timeout — server connected but did not respond in time.")
            print(" The Gemini API call in step() may be hanging.")
            print(" Check GOOGLE_API_KEY is set, or that mock mode is active.")
        sys.exit(1)
    except OSError as e:
        # e.g. asyncio's "Multiple exceptions: ..." when localhost resolves to
        # both ::1 and 127.0.0.1 and every address refuses the connection.
        print()
        if connected:
            print(f"FAIL Network error during session: {type(e).__name__}: {e}")
        else:
            print(f"FAIL Could not connect: {e}")
            print(f" Tried: {WS_URL}")
            print(" Start the env server first, then rerun this test.")
        sys.exit(1)
    except Exception as e:
        print()
        print(f"FAIL Unexpected error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Script entry point: drive the async smoke test to completion with
# asyncio.run(). smoke_test() calls sys.exit(1) itself on every failure path
# it handles, so returning normally means every hard check passed (WARN lines
# are informational only).
if __name__ == "__main__":
    asyncio.run(smoke_test())