import asyncio
import json
import os
import sys

import websockets
|
|
| |
# WebSocket endpoint of the env server under test. Set the WS_URL environment
# variable to target another host/port/path; an unset or empty value falls
# back to the local default below.
WS_URL = os.environ.get("WS_URL") or "ws://localhost:8001/env/ws"
|
|
|
|
| def _observation(data: dict) -> dict: |
| """Reset/step responses wrap the observation; unwrap for assertions.""" |
| if "observation" in data and isinstance(data["observation"], dict): |
| return data["observation"] |
| if "error" in data: |
| raise AssertionError(f"server error: {data.get('error')}") |
| return data |
|
|
|
|
async def _request(ws, method: str, params: dict, timeout: float) -> dict:
    """Send one ``{"method": ..., "params": ...}`` JSON frame and decode the reply.

    Args:
        ws: An open websockets client connection.
        method: Server method name, e.g. ``"reset"``, ``"step"``, ``"state"``.
        params: JSON-serialisable parameters for ``method``.
        timeout: Seconds to wait for the reply frame.

    Returns:
        The reply decoded with ``json.loads``. It is assumed to be a JSON
        object; this helper does not validate that.

    Raises:
        asyncio.TimeoutError: If no reply arrives within ``timeout`` seconds.
        json.JSONDecodeError: If the reply frame is not valid JSON.
    """
    await ws.send(json.dumps({"method": method, "params": params}))
    raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
    return json.loads(raw)


async def smoke_test() -> None:
    """Run an end-to-end smoke test against the env server's WebSocket.

    Connects to ``WS_URL`` and exercises, in order:

    * ``reset``: must return a ``session_id`` and an observation that
      contains ``step_count``.
    * ``step``: the observation must contain ``reward``.
    * ``state``: must return a non-empty dict that is not an error reply.
    * ZOPA-erosion fields and ``last_utterance``: warn only if missing.
    * ``reward``: must be numeric.
    * A second ``step``: its ``step_count`` must exceed the first one's.

    Checks raise ``AssertionError`` explicitly rather than using ``assert``,
    so they still fire under ``python -O``. Every failure path prints a
    ``FAIL`` line with hints and calls ``sys.exit(1)``. On success a summary
    is printed and the coroutine returns normally.
    """
    print(f"Connecting to {WS_URL}...")
    try:
        async with websockets.connect(
            WS_URL,
            open_timeout=8,
            # NOTE(review): Origin names port 8000 while WS_URL defaults to
            # port 8001. Presumably the server's origin policy expects the
            # 8000 frontend; confirm before changing either value.
            additional_headers={"Origin": "http://localhost:8000"},
        ) as ws:
            # --- reset(): open a session --------------------------------
            print("Testing reset()...")
            data0 = await _request(
                ws,
                "reset",
                {"persona": "shark", "scenario_id": "saas_enterprise"},
                timeout=10,
            )
            session_id = data0.get("session_id")
            if not session_id and "error" in data0:
                raise AssertionError(f"reset failed: {data0!r}")
            if not session_id:
                raise AssertionError(f"reset() missing session_id. Got: {data0!r}")
            obs = _observation(data0)
            print(f" reset() response keys: {sorted(data0.keys())}")
            if "step_count" not in obs:
                raise AssertionError(
                    f"reset() missing step_count. Got: {list(obs.keys())}"
                )
            print(
                f" PASS reset() | step_count={obs.get('step_count')} "
                f"tension_score={obs.get('tension_score')} "
                f"cumulative_reward={obs.get('cumulative_reward', 'N/A')}"
            )

            # --- step(): first offer ------------------------------------
            print("Testing step()...")
            data1 = await _request(
                ws,
                "step",
                {
                    "session_id": session_id,
                    "utterance": "We are prepared to offer 500000 for the annual contract (probe).",
                    "offer_amount": 500_000.0,
                    "tactical_move": None,
                },
                timeout=20,
            )
            if "error" in data1 and "observation" not in data1:
                raise AssertionError(f"step failed: {data1!r}")
            obs2 = _observation(data1)
            print(f" step() envelope keys: {sorted(data1.keys())}")
            print(f" step() observation keys: {sorted(obs2.keys())}")
            if "reward" not in obs2:
                raise AssertionError(
                    f"step() missing reward. Got: {list(obs2.keys())}"
                )
            print(
                f" PASS step() | reward={obs2.get('reward')} "
                f"step_count={obs2.get('step_count')} "
                f"tension_score={obs2.get('tension_score')}"
            )

            # --- state(): session snapshot ------------------------------
            print("Testing state()...")
            st = await _request(ws, "state", {"session_id": session_id}, timeout=8)
            state = st.get("state", st) if isinstance(st, dict) else st
            print(f" state() response keys: {sorted(st.keys()) if isinstance(st, dict) else st}")
            # An error-only reply is itself a non-empty dict, so it would
            # satisfy the emptiness check below; reject it explicitly.
            if isinstance(st, dict) and "error" in st and "state" not in st:
                raise AssertionError(f"state failed: {st!r}")
            if not (isinstance(state, dict) and len(state) > 0):
                raise AssertionError("state() returned empty")
            print(f" PASS state() | {len(state)} keys in state payload")

            # --- optional ZOPA erosion fields (warn only) ---------------
            print("Testing ZOPA erosion fields...")
            has_erosion = "zopa_erosion_ticks" in obs2 or "zopa_width_pct_remaining" in obs2
            if has_erosion:
                width = obs2.get("zopa_width_pct_remaining", 1.0)
                # Only apply the float format spec to real numbers. A null or
                # string value would otherwise raise and be misreported as an
                # unexpected error.
                width_txt = f"{width:.2f}" if isinstance(width, (int, float)) else str(width)
                print(
                    f" PASS ZOPA erosion fields present | "
                    f"ticks={obs2.get('zopa_erosion_ticks', 0)} "
                    f"width_pct={width_txt}"
                )
            else:
                print(f" WARN ZOPA erosion fields NOT in obs2 keys: {sorted(obs2.keys())}")
                print(
                    " This means the ZOPA collapse mechanic is not "
                    "surfaced in the observation model."
                )

            # --- reward must be numeric (warn if exactly zero) ----------
            print("Testing reward is non-trivial...")
            reward_val = obs2.get("reward")
            if not isinstance(reward_val, (int, float)):
                raise AssertionError(f"reward is not numeric: {type(reward_val)}")
            print(f" PASS reward is numeric: {reward_val}")
            if reward_val == 0.0:
                print(
                    " WARN reward is exactly 0.0 — step rewards may not "
                    "be accumulating. Check grader.py compute_step_reward."
                )

            # --- last_utterance echoed back (warn only) -----------------
            print("Testing utterance in observation...")
            last_u = obs2.get("last_utterance", "")
            if last_u:
                preview = (last_u[:80] + "...") if len(last_u) > 80 else last_u
                print(f" PASS last_utterance set: '{preview}'")
            else:
                print(" WARN last_utterance empty in obs2")

            # --- second step(): step_count must increase ----------------
            print("Testing turn increment on second step...")
            data2 = await _request(
                ws,
                "step",
                {
                    "session_id": session_id,
                    "utterance": "I understand your position. We could go to 501000 (still probing).",
                    "offer_amount": 501_000.0,
                    "tactical_move": None,
                },
                timeout=20,
            )
            if "error" in data2 and "observation" not in data2:
                raise AssertionError(f"step failed: {data2!r}")
            obs3 = _observation(data2)
            turn2 = obs3.get("step_count", -1)
            turn1 = obs2.get("step_count", -1)
            # Guard the comparison: a null step_count would otherwise raise
            # TypeError instead of a clear assertion failure.
            both_numeric = all(isinstance(t, (int, float)) for t in (turn1, turn2))
            if not (both_numeric and turn2 > turn1):
                raise AssertionError(
                    f"Step count did not increment: was {turn1}, now {turn2}"
                )
            print(f" PASS Step incremented: {turn1} -> {turn2}")

            print()
            print("=" * 55)
            print(" ALL SMOKE TESTS PASSED")
            print(" The WebSocket env is working correctly.")
            print("=" * 55)
            print()
            print("QUICK STATS:")
            print(f" WebSocket URL: {WS_URL}")
            print(f" Reward at step 1: {obs2.get('reward')}")
            print(f" Reward at step 2: {obs3.get('reward')}")
            print(f" Tension (step 2): {obs3.get('tension_score')}")
            print(f" Step count: {turn2}")

    except ConnectionRefusedError:
        print()
        print("FAIL ConnectionRefused — server is not running on this port.")
        print(f" Tried: {WS_URL}")
        print(" Start the env server first, then rerun this test.")
        sys.exit(1)
    except websockets.exceptions.InvalidStatus as e:
        print()
        print(f"FAIL Server rejected connection: {e}")
        code = getattr(e, "response", None)
        status = code.status_code if code is not None and hasattr(code, "status_code") else "unknown"
        print(f" HTTP status: {status}")
        print(" This is the 403 bug — wrong path or host policy.")
        print(" Confirm WS_URL is ws://localhost:PORT/env/ws (not /ws).")
        sys.exit(1)
    except AssertionError as e:
        print()
        print(f"FAIL Assertion failed: {e}")
        sys.exit(1)
    except asyncio.TimeoutError:
        print()
        print("FAIL Timeout — server connected but did not respond in time.")
        print(" The Gemini API call in step() may be hanging.")
        print(" Check GOOGLE_API_KEY is set, or that mock mode is active.")
        sys.exit(1)
    except OSError as e:
        # Must stay after the TimeoutError clause: since Python 3.11,
        # asyncio.TimeoutError is the builtin TimeoutError, an OSError subclass.
        # When the host resolves to several addresses (e.g. localhost -> ::1
        # and 127.0.0.1) and every attempt fails, asyncio can raise one
        # combined plain OSError instead of ConnectionRefusedError.
        print()
        print(f"FAIL Could not reach server: {type(e).__name__}: {e}")
        print(f" Tried: {WS_URL}")
        print(" Start the env server first, then rerun this test.")
        sys.exit(1)
    except Exception as e:
        print()
        print(f"FAIL Unexpected error: {type(e).__name__}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    # Run the async smoke test to completion on a fresh event loop when this
    # file is executed as a script (not when it is imported).
    main_coro = smoke_test()
    asyncio.run(main_coro)
|
|