claims-env / test_websocket.py
akhiilll's picture
Deploy ClaimSense adjudication gym
1cfeb15 verified
#!/usr/bin/env python3
"""Smoke test that drives the gym through a full episode over WebSocket.

Run::

    python test_websocket.py                          # talk to a local uvicorn
    CLAIMS_ENV_WS=wss://… python test_websocket.py    # against the deployed Space

Prints a short summary for each step and asserts on the basics (reset
returns a claim, terminal verdict produces a reward).
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import websockets
# Endpoint under test; override with the CLAIMS_ENV_WS environment variable.
WS_URL = os.getenv("CLAIMS_ENV_WS", "ws://127.0.0.1:7860/ws")
# Horizontal rule printed around the banner and the final verdict.
DIVIDER = 60 * "="
async def _exchange(ws, message: dict) -> dict:
    """Send ``message`` as JSON text, then decode and return the next frame.

    ``ws`` only needs awaitable ``send`` and ``recv`` methods.
    """
    outgoing = json.dumps(message)
    await ws.send(outgoing)
    incoming = await ws.recv()
    return json.loads(incoming)
async def _step(ws, action_type: str, **parameters) -> dict:
    """Issue a single ``step`` message and hand back the decoded reply.

    Any keyword arguments are forwarded unchanged as the action's
    ``parameters`` mapping.
    """
    action = {"action_type": action_type, "parameters": parameters}
    return await _exchange(ws, {"type": "step", "data": action})
def _report_error(reply: dict, stage: str) -> bool:
    """Print a failure line if ``reply`` is an ``error`` message; return True if so."""
    if reply.get("type") != "error":
        return False
    print(f"{stage} failed: {reply.get('data')}")
    return True


async def run_episode() -> int:
    """Drive one full episode over the WebSocket and return a process exit code.

    The sequence is reset, query_policy, check_fraud, verify_purchase and a
    final verdict. The verdict is ``deny`` when the reported fraud risk score
    is above 0.5, otherwise ``approve`` with a payout of 90% of the requested
    claim amount.

    Returns:
        0 when the episode completes and the final checks pass; 1 when the
        server answers any request with an ``error`` message.

    Raises:
        AssertionError: if the verdict observation is not terminal or the
            verdict reply carries no reward.
    """
    print(DIVIDER)
    print(f"ClaimSense WebSocket smoke test → {WS_URL}")
    print(DIVIDER)

    async with websockets.connect(WS_URL) as ws:
        # ------------------------------------------------------------- reset
        reply = await _exchange(ws, {"type": "reset", "data": {}})
        if _report_error(reply, "reset"):
            return 1
        obs = reply["data"]["observation"]
        claim_amount = float(obs["claim_amount_requested"])
        print("\n[1] reset")
        print(f" claim_id = {obs['claim_id']}")
        print(f" claim_type = {obs['claim_type']}")
        print(f" claim_amount = ${claim_amount:,.2f}")
        print(f" description = {obs['description'][:80]}…")

        # ------------------------------------------------------ query_policy
        reply = await _step(ws, "query_policy")
        if _report_error(reply, "query_policy"):
            return 1
        print("\n[2] query_policy → "
              f"{reply['data']['observation']['system_response'][:100]}…")

        # -------------------------------------------------------- check_fraud
        reply = await _step(ws, "check_fraud")
        if _report_error(reply, "check_fraud"):
            return 1
        obs = reply["data"]["observation"]
        fraud = obs["revealed_info"].get("fraud_analysis", {})
        # A missing score is treated as zero risk, which leads to approval.
        score = float(fraud.get("risk_score", 0))
        print(f"\n[3] check_fraud → risk_score={score:.2f} "
              f"({fraud.get('recommendation', '?')})")

        # ----------------------------------------------------- verify_purchase
        reply = await _step(ws, "verify_purchase")
        if _report_error(reply, "verify_purchase"):
            return 1
        print("\n[4] verify_purchase → "
              f"{reply['data']['observation']['system_response'][:120]}…")

        # ---------------------------------------------------------- decision
        if score > 0.5:
            action_type = "deny"
            parameters = {"reason": "fraud risk above threshold"}
            label = "DENY (fraud)"
        else:
            payout = round(claim_amount * 0.9, 2)
            action_type = "approve"
            parameters = {"payout": payout}
            label = f"APPROVE (${payout:,.2f})"
        print(f"\n[5] verdict → {label}")

        # Same wire format as before, routed through the shared step helper.
        reply = await _step(ws, action_type, **parameters)
        if _report_error(reply, "verdict"):
            return 1
        out = reply["data"]
        terminal = out["observation"]
        reward = out.get("reward")
        print(f" terminal = {terminal.get('is_terminal')}")
        print(f" terminal_reason = {terminal.get('terminal_reason')}")
        print(f" reward = {reward}")

        # Teardown is best-effort: the verdict is already captured, so a
        # server that hangs up instead of acknowledging "close" must not
        # fail the run before the checks below.
        try:
            await _exchange(ws, {"type": "close", "data": {}})
        except websockets.ConnectionClosed:
            pass

    # --------------------------------------------------------- assertions
    assert terminal.get("is_terminal") is True, "expected terminal observation"
    assert reward is not None, "terminal step must return a reward"
    print(f"\n{DIVIDER}\nsmoke test PASSED\n{DIVIDER}")
    return 0
if __name__ == "__main__":
    # Propagate the episode's result as the process exit status.
    exit_code = asyncio.run(run_episode())
    sys.exit(exit_code)