File size: 4,091 Bytes
1cfeb15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""Smoke test that drives the gym through a full episode over WebSocket.

Run::

    python test_websocket.py                          # local uvicorn at ws://127.0.0.1:7860/ws
    CLAIMS_ENV_WS=wss://… python test_websocket.py    # against the deployed Space

Prints a short summary of each step and checks the basics (reset returns a
claim; the final verdict ends the episode with a reward), exiting with a
non-zero status on failure.
"""

from __future__ import annotations

import asyncio
import json
import os
import sys

import websockets


# Endpoint used when CLAIMS_ENV_WS is not set: a locally running server.
_DEFAULT_WS_URL = "ws://127.0.0.1:7860/ws"
# WebSocket endpoint of the gym; CLAIMS_ENV_WS overrides the local default.
WS_URL = os.environ.get("CLAIMS_ENV_WS", _DEFAULT_WS_URL)
# Horizontal rule printed around the banner and the final summary.
DIVIDER = 60 * "="


async def _exchange(ws, message: dict) -> dict:
    await ws.send(json.dumps(message))
    return json.loads(await ws.recv())


async def _step(ws, action_type: str, **parameters) -> dict:
    """Perform one environment ``step`` and return the decoded reply.

    *action_type* names the action; every extra keyword argument becomes an
    entry of the action's ``parameters`` mapping in the request.
    """
    action = {"action_type": action_type, "parameters": parameters}
    return await _exchange(ws, {"type": "step", "data": action})


def _preview(text: object, limit: int) -> str:
    """Return *text* as a string cut to *limit* characters.

    An ellipsis is appended only when something was actually cut, and
    ``None`` renders as an empty string so a missing field cannot crash
    the report.
    """
    shown = "" if text is None else str(text)
    return shown if len(shown) <= limit else shown[:limit] + "…"


def _report_error(reply: dict, label: str) -> bool:
    """Print a failure line and return True if *reply* is an ``error`` reply.

    Returns False (and prints nothing) for any other reply type.
    """
    if reply.get("type") != "error":
        return False
    print(f"{label} failed: {reply.get('data')}")
    return True


async def run_episode() -> int:
    """Drive one scripted episode against ``WS_URL`` and report the outcome.

    The script is: reset → query_policy → check_fraud → verify_purchase →
    verdict → close. The verdict is ``deny`` when the fraud ``risk_score``
    exceeds 0.5, otherwise ``approve`` with a payout of 90% of the requested
    amount.

    Returns:
        Process exit status: 0 when the verdict yields a terminal observation
        and a reward; 1 when any reply is an ``error`` or a check fails.
    """
    print(DIVIDER)
    print(f"ClaimSense WebSocket smoke test → {WS_URL}")
    print(DIVIDER)

    async with websockets.connect(WS_URL) as ws:
        # ------------------------------------------------------------- reset
        reply = await _exchange(ws, {"type": "reset", "data": {}})
        if _report_error(reply, "reset"):
            return 1

        obs = reply["data"]["observation"]
        claim_amount = float(obs["claim_amount_requested"])
        print("\n[1] reset")
        print(f"    claim_id        = {obs['claim_id']}")
        print(f"    claim_type      = {obs['claim_type']}")
        print(f"    claim_amount    = ${claim_amount:,.2f}")
        print(f"    description     = {_preview(obs['description'], 80)}")

        # ------------------------------------------------------ query_policy
        # Every step reply is checked for an error before its payload is
        # indexed, so a server-side failure is reported instead of surfacing
        # as a KeyError.
        reply = await _step(ws, "query_policy")
        if _report_error(reply, "query_policy"):
            return 1
        response = reply["data"]["observation"]["system_response"]
        print(f"\n[2] query_policy → {_preview(response, 100)}")

        # -------------------------------------------------------- check_fraud
        reply = await _step(ws, "check_fraud")
        if _report_error(reply, "check_fraud"):
            return 1
        obs = reply["data"]["observation"]
        fraud = obs["revealed_info"].get("fraud_analysis", {})
        score = float(fraud.get("risk_score", 0))
        print(f"\n[3] check_fraud → risk_score={score:.2f} "
              f"({fraud.get('recommendation', '?')})")

        # ----------------------------------------------------- verify_purchase
        reply = await _step(ws, "verify_purchase")
        if _report_error(reply, "verify_purchase"):
            return 1
        response = reply["data"]["observation"]["system_response"]
        print(f"\n[4] verify_purchase → {_preview(response, 120)}")

        # ---------------------------------------------------------- decision
        if score > 0.5:
            action, params = "deny", {"reason": "fraud risk above threshold"}
            label = "DENY (fraud)"
        else:
            payout = round(claim_amount * 0.9, 2)
            action, params = "approve", {"payout": payout}
            label = f"APPROVE (${payout:,.2f})"

        print(f"\n[5] verdict → {label}")
        reply = await _step(ws, action, **params)
        if _report_error(reply, "verdict"):
            return 1
        out = reply["data"]
        terminal = out["observation"]
        reward = out.get("reward")
        print(f"    terminal        = {terminal.get('is_terminal')}")
        print(f"    terminal_reason = {terminal.get('terminal_reason')}")
        print(f"    reward          = {reward}")

        try:
            await _exchange(ws, {"type": "close", "data": {}})
        except websockets.ConnectionClosed:
            # The server may hang up instead of answering "close"; after the
            # verdict that is a normal shutdown, not a test failure.
            pass

    # ---------------------------------------------------------------- checks
    # Explicit checks rather than ``assert`` so they still run under -O.
    if terminal.get("is_terminal") is not True:
        print("smoke test FAILED: expected terminal observation")
        return 1
    if reward is None:
        print("smoke test FAILED: terminal step must return a reward")
        return 1

    print(f"\n{DIVIDER}\nsmoke test PASSED\n{DIVIDER}")
    return 0


if __name__ == "__main__":
    # Run the episode on a fresh event loop and hand its integer result
    # (0 on success, 1 on a reported failure) to the OS as the exit status.
    exit_status = asyncio.run(run_episode())
    sys.exit(exit_status)