File size: 14,915 Bytes
03815d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""Integration tests for the OpenEnv-compliant ChakravyuhOpenEnv wrapper.

Exercise the full reset/step/state contract without going through HTTP —
this verifies the Environment subclass obeys the OpenEnv interface
(reset returns ObsT, step returns ObsT with done/reward, state is
introspectable) and that 2-decision episodes produce coherent rewards.
"""

from __future__ import annotations

import pytest

from chakravyuh_env import (
    ChakravyuhAction,
    ChakravyuhObservation,
    ChakravyuhOpenEnv,
    ChakravyuhState,
)
from chakravyuh_env.openenv_models import ChakravyuhAction as _A
from chakravyuh_env.schemas import VictimProfile


# ---------------------------------------------------------------------------
# Basic contract: reset/step/state shapes
# ---------------------------------------------------------------------------


@pytest.mark.unit
def test_reset_returns_observation() -> None:
    """reset() hands back a live observation positioned at the first decision."""
    wrapper = ChakravyuhOpenEnv()
    first_obs = wrapper.reset(seed=42)

    assert isinstance(first_obs, ChakravyuhObservation)
    # Gym/OpenEnv convention: the observation from reset() is never terminal,
    # so the agent is always given at least one step.
    assert first_obs.done is False
    assert first_obs.decision_index == 0
    # The scammer's opener (turn 1) and the victim's reply (turn 2) have
    # already been played by the time reset() returns.
    assert first_obs.turn == 2
    who_spoke = set()
    for message in first_obs.chat_history:
        who_spoke.add(message["sender"])
    assert who_spoke == {"scammer", "victim"}


@pytest.mark.unit
def test_state_property_before_and_after_reset() -> None:
    """``state`` is readable before any reset and is filled in by reset()."""
    wrapper = ChakravyuhOpenEnv()

    # No episode exists yet, so there is no episode id to report.
    pristine = wrapper.state
    assert isinstance(pristine, ChakravyuhState)
    assert pristine.episode_id is None

    wrapper.reset(seed=7)

    populated = wrapper.state
    assert populated.episode_id is not None
    assert populated.scam_category is not None
    # An env built without arguments is expected to report the SEMI_URBAN profile.
    assert populated.victim_profile == VictimProfile.SEMI_URBAN.value


@pytest.mark.unit
def test_step_requires_prior_reset() -> None:
    """Calling step() on an env that was never reset raises RuntimeError."""
    never_reset = ChakravyuhOpenEnv()
    expect_reset_error = pytest.raises(RuntimeError, match="reset")
    with expect_reset_error:
        never_reset.step(_A(score=0.5))


# ---------------------------------------------------------------------------
# Episode progression: 2-decision flow
# ---------------------------------------------------------------------------


def _run_episode(
    env: ChakravyuhOpenEnv,
    seed: int,
    score1: float = 0.9,
    score2: float = 0.9,
) -> ChakravyuhObservation:
    """Reset ``env`` with ``seed`` and play at most two analyzer decisions.

    ``score1`` and ``score2`` are the scores submitted at the first and second
    decision. Returns the observation of the last step taken: the first
    step's if it already ended the episode, otherwise the second step's.
    """
    opening = env.reset(seed=seed)
    assert opening.done is False, "reset() must return a non-terminal observation"

    first = env.step(_A(score=score1, signals=["urgency"], explanation="t1"))
    if not first.done:
        return env.step(
            _A(score=score2, signals=["impersonation"], explanation="t2")
        )
    return first


@pytest.mark.integration
@pytest.mark.parametrize("seed", [1, 42, 99, 256, 1000])
def test_episode_eventually_terminates(seed: int) -> None:
    """A two-decision run ends the episode and yields a sane float reward."""
    wrapper = ChakravyuhOpenEnv()
    final_obs = _run_episode(wrapper, seed=seed)

    assert final_obs.done is True
    assert final_obs.reward is not None
    assert isinstance(final_obs.reward, float)
    # The analyzer reward lives in roughly [-0.8, +1.4]; this is a deliberately
    # loose sanity bound rather than an exact range check.
    assert -5.0 < final_obs.reward < 5.0


@pytest.mark.integration
def test_high_suspicion_score_flags_scam() -> None:
    """Scoring 0.99 against a 0.5 threshold at every decision flags the episode.

    The terminal observation is verified regardless of which step produced
    it: if the first decision already ends the episode, that step's
    observation is the one checked.
    """
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    if obs.done:
        pytest.skip("Episode ended at reset for this seed")

    # Flag aggressively at both decision points. Keep the first step's
    # observation so the terminal checks below still run when the episode
    # finishes after a single decision.
    final = env.step(
        _A(score=0.99, flag_threshold=0.5, signals=["urgency"], explanation="flag")
    )
    if not env.state.done:
        final = env.step(
            _A(score=0.99, flag_threshold=0.5, signals=["impersonation"], explanation="flag")
        )

    state = env.state
    assert state.analyzer_flagged is True
    assert state.done is True
    assert final.done is True
    assert final.outcome is not None
    assert final.outcome.get("analyzer_flagged") is True


@pytest.mark.integration
def test_low_suspicion_score_does_not_flag() -> None:
    """Scoring 0.10 against a 0.5 threshold at every decision never flags."""

    def lowball() -> ChakravyuhAction:
        # A fresh under-scored action for each decision point.
        return _A(score=0.10, flag_threshold=0.5, explanation="ignore")

    sim = ChakravyuhOpenEnv()
    if sim.reset(seed=42).done:
        pytest.skip("Episode ended at reset for this seed")

    # Deliberately under-score at both decision points → no flag.
    sim.step(lowball())
    still_running = not sim.state.done
    if still_running:
        sim.step(lowball())

    assert sim.state.analyzer_flagged is False


@pytest.mark.integration
def test_step_after_done_raises() -> None:
    """Stepping a finished episode raises RuntimeError ("already done")."""
    finished_env = ChakravyuhOpenEnv()
    last_obs = _run_episode(finished_env, seed=3)
    assert last_obs.done is True

    rejects_extra_step = pytest.raises(RuntimeError, match="already done")
    with rejects_extra_step:
        finished_env.step(_A(score=0.5))


# ---------------------------------------------------------------------------
# Reward & observation payload sanity
# ---------------------------------------------------------------------------


@pytest.mark.integration
def test_terminal_observation_includes_reward_breakdown() -> None:
    """The terminal observation exposes the full reward breakdown and outcome."""
    # Composable-rubric breakdown: one entry per child rubric, the weighted
    # total, the weight map that was used, and a legacy-analyzer reference
    # value kept for comparison.
    breakdown_keys = (
        "total",
        "detection",
        "missed_scam",
        "false_positive",
        "calibration",
        "explanation",
        "weights",
        "legacy_analyzer",
    )
    outcome_keys = ("money_extracted", "analyzer_flagged", "detected_by_turn")

    wrapper = ChakravyuhOpenEnv()
    terminal = _run_episode(wrapper, seed=11, score1=0.9, score2=0.9)

    assert terminal.done is True
    assert terminal.reward_breakdown is not None
    for name in breakdown_keys:
        assert name in terminal.reward_breakdown, f"reward_breakdown missing {name}"
    assert terminal.outcome is not None
    for name in outcome_keys:
        assert name in terminal.outcome


@pytest.mark.integration
@pytest.mark.parametrize("seed", range(50))
def test_reset_is_always_non_terminal_across_seeds(seed: int) -> None:
    """Regression: reset() must never hand back done=True, whatever the seed.

    A YOUNG_URBAN victim with low trust plus a turn-1 info request used to be
    able to VictimRefuse on turn 2, so reset() returned a zero-step terminal
    episode. That violates the Gym/OpenEnv convention and breaks training
    loops that expect at least one step per trajectory.
    """
    # The refuse path depends on the base trust level, so every victim
    # profile is exercised for each seed.
    for victim in VictimProfile:
        wrapper = ChakravyuhOpenEnv(victim_profile=victim)
        opening = wrapper.reset(seed=seed)
        assert opening.done is False, (
            f"reset() returned done=True on seed={seed}, profile={victim.value} "
            f"(chat_history={[turn['text'][:50] for turn in opening.chat_history]})"
        )
        assert opening.decision_index == 0


@pytest.mark.integration
def test_regulator_state_does_not_leak_across_resets() -> None:
    """Regression: reset() must clear the regulator's per-episode outcome buffer.

    The regulator used to be created once in __init__, so outcomes piled up
    across resets of a single env instance and polluted downstream
    rule-update logic under WebSocket session reuse.
    """
    wrapper = ChakravyuhOpenEnv()

    # First episode: play it to the end so exactly one outcome is recorded.
    wrapper.reset(seed=1)
    wrapper.step(_A(score=0.9))
    if not wrapper.state.done:
        wrapper.step(_A(score=0.9))
    first_regulator = wrapper._regulator
    assert first_regulator is not None
    assert len(first_regulator._outcome_buffer) == 1

    # Second episode: reset() has to swap in a brand-new, empty regulator.
    wrapper.reset(seed=2)
    second_regulator = wrapper._regulator
    assert second_regulator is not None
    assert second_regulator is not first_regulator, (
        "reset() must create a new regulator, not reuse the previous one"
    )
    assert len(second_regulator._outcome_buffer) == 0


@pytest.mark.integration
def test_concurrent_instances_have_isolated_regulators() -> None:
    """Separate env instances must not share regulator state.

    This isolation is what SUPPORTS_CONCURRENT_SESSIONS=True relies on.
    """
    active = ChakravyuhOpenEnv()
    idle = ChakravyuhOpenEnv()

    active.reset(seed=1)
    idle.reset(seed=2)

    # Only ``active`` is stepped through to the end of its episode.
    active.step(_A(score=0.9))
    if not active.state.done:
        active.step(_A(score=0.9))

    # ``idle`` never saw a step() call, so its regulator must still be empty
    # even though ``active`` has logged an outcome.
    active_regulator = active._regulator
    idle_regulator = idle._regulator
    assert active_regulator is not idle_regulator
    assert len(active_regulator._outcome_buffer) == 1
    assert len(idle_regulator._outcome_buffer) == 0


@pytest.mark.integration
def test_invalid_signal_name_raises() -> None:
    """An unknown signal name is rejected with a ValueError naming AnalyzerSignal."""
    wrapper = ChakravyuhOpenEnv()
    if wrapper.reset(seed=42).done:
        pytest.skip("Episode ended at reset for this seed")

    bogus_signals = ["not_a_real_signal"]
    # The action is built inside the context manager so the error is caught
    # whether it is raised while constructing the action or while stepping.
    with pytest.raises(ValueError, match="AnalyzerSignal"):
        wrapper.step(_A(score=0.9, signals=bogus_signals))


@pytest.mark.integration
def test_determinism_same_seed_same_transcript() -> None:
    """Two envs reset with the same seed open with identical message texts."""
    left_env = ChakravyuhOpenEnv()
    right_env = ChakravyuhOpenEnv()
    left_obs = left_env.reset(seed=1234)
    right_obs = right_env.reset(seed=1234)

    left_texts, right_texts = (
        [turn["text"] for turn in observation.chat_history]
        for observation in (left_obs, right_obs)
    )
    assert left_texts == right_texts


@pytest.mark.unit
def test_observation_round_trips_through_json() -> None:
    """An observation must survive a JSON dump/load cycle unchanged, so that
    OpenEnv's wire transport is lossless."""
    wrapper = ChakravyuhOpenEnv()
    original = wrapper.reset(seed=42)
    payload = original.model_dump_json()
    restored = ChakravyuhObservation.model_validate_json(payload)

    scalar_fields = (
        "turn",
        "decision_index",
        "episode_id",
        "scam_category",
        "victim_profile",
    )
    for field_name in scalar_fields:
        assert getattr(restored, field_name) == getattr(original, field_name)
    assert len(restored.chat_history) == len(original.chat_history)
    assert restored.schema_version == original.schema_version


@pytest.mark.unit
def test_observation_carries_schema_version() -> None:
    """Each observation is stamped with the OpenEnv schema_version, letting
    older training runs detect a wire-format mismatch on replay."""
    from chakravyuh_env.openenv_models import CHAKRAVYUH_SCHEMA_VERSION

    wrapper = ChakravyuhOpenEnv()
    stamped = wrapper.reset(seed=42).schema_version
    assert stamped == CHAKRAVYUH_SCHEMA_VERSION
    # Pin the literal too, so a version bump has to be made consciously here.
    assert stamped == "0.2.0"


@pytest.mark.unit
def test_chat_turn_validator_documents_wire_shape() -> None:
    """ChatTurn must accept every chat_history dict the env actually emits,
    guarding against drift between the wire shape and the documented one."""
    from chakravyuh_env.openenv_models import ChatTurn

    known_senders = {"scammer", "victim", "analyzer", "bank_monitor", "regulator"}

    wrapper = ChakravyuhOpenEnv()
    opening = wrapper.reset(seed=42)
    for raw_turn in opening.chat_history:
        parsed = ChatTurn.model_validate(raw_turn)
        assert parsed.sender in known_senders
        assert parsed.turn >= 0


@pytest.mark.unit
def test_reward_breakdown_validator_matches_terminal_obs_shape() -> None:
    """Terminal observations carry a reward_breakdown dict — its shape
    must match the RewardBreakdown documented schema.

    The walk to the terminal observation is bounded and the breakdown is
    required to be present, so the test can neither hang on an episode that
    never ends nor pass without validating anything.
    """
    from chakravyuh_env.openenv_models import RewardBreakdown

    env = ChakravyuhOpenEnv()
    env.reset(seed=42)
    obs = env.step(_A(score=0.95, signals=["urgency"], explanation="OTP request"))

    # Walk through to terminal if not yet done. Episodes are expected to end
    # after two decisions; the cap only exists to turn a runaway episode into
    # a test failure instead of an endless loop.
    max_extra_steps = 10
    for _ in range(max_extra_steps):
        if obs.done:
            break
        obs = env.step(_A(score=0.95, signals=["urgency"], explanation="next"))
    assert obs.done is True, (
        f"episode still running after {max_extra_steps + 1} steps"
    )

    assert obs.reward_breakdown is not None, (
        "terminal observation is missing reward_breakdown"
    )
    # extra="ignore" lets unknown keys (e.g. composite_unweighted) pass.
    validated = RewardBreakdown.model_validate(obs.reward_breakdown)
    assert isinstance(validated.composite, float)


# ---------------------------------------------------------------------------
# Server factory smoke test — ensures create_app wiring doesn't explode.
# Runs as a pure import/invocation test; no HTTP server is actually bound.
# ---------------------------------------------------------------------------


@pytest.mark.integration
def test_fastapi_app_builds() -> None:
    """Importing the server module yields an app exposing the expected routes."""
    from server.app import app

    assert app is not None

    # Not every route object necessarily has a ``path``; those map to None.
    registered = {getattr(route, "path", None) for route in app.routes}
    # ``/health`` and ``/schema`` belong to the stock OpenEnv HTTP surface;
    # the remaining three are the environment endpoints.
    for endpoint in ("/health", "/schema", "/reset", "/step", "/state"):
        assert endpoint in registered


@pytest.mark.integration
def test_websocket_full_episode_round_trip() -> None:
    """Exercise a real HTTP server + WebSocket client → the path judges use.

    Starts uvicorn as a subprocess on a free port, connects via
    ``ChakravyuhEnvClient`` (OpenEnv's standard WS client), runs a 2-step
    episode, and asserts terminal reward + outcome fields arrive intact.

    The test is skipped (not failed) when the server does not come up: either
    the process exits during startup or ``/health`` stays unreachable for
    10 s. The server process is always terminated and reaped on exit.
    """
    import socket
    import subprocess
    import sys
    import time
    import urllib.request
    from pathlib import Path

    repo_root = Path(__file__).resolve().parent.parent

    # Pick a free port by binding to 0 and releasing. Another process could
    # grab the port before uvicorn binds it; that shows up as a startup skip.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]

    proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "server.app:app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
            "--log-level",
            "error",
        ],
        cwd=str(repo_root),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    try:
        # Poll /health until the server is ready (up to 10 s). A monotonic
        # clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + 10
        url = f"http://127.0.0.1:{port}/health"
        while time.monotonic() < deadline:
            # No point waiting out the deadline if uvicorn has already died.
            if proc.poll() is not None:
                pytest.skip(
                    f"Server process exited during startup (code {proc.returncode})"
                )
            try:
                # The context manager closes the probe connection promptly.
                with urllib.request.urlopen(url, timeout=0.5) as response:
                    response.read()
                break
            except Exception:
                # Best-effort readiness probe: any failure here just means
                # "not ready yet", so back off briefly and retry.
                time.sleep(0.3)
        else:
            pytest.skip("Server failed to start in 10 s")

        from chakravyuh_env.openenv_client import ChakravyuhEnvClient

        with ChakravyuhEnvClient(base_url=f"http://127.0.0.1:{port}").sync() as env:
            r = env.reset(seed=42)
            assert not r.done
            assert r.observation.turn >= 2
            assert len(r.observation.chat_history) >= 2

            r = env.step(_A(score=0.9, signals=["urgency"]))
            if not r.done:
                r = env.step(_A(score=0.9, signals=["impersonation"]))

            assert r.done is True
            assert r.reward is not None
            assert r.observation.outcome is not None
            assert r.observation.reward_breakdown is not None
            # NOTE(review): presumably 0.9 exceeds the action's default flag
            # threshold — confirm against ChakravyuhAction.
            assert r.observation.outcome.get("analyzer_flagged") is True
    finally:
        # Ask for a graceful shutdown first; escalate only if it is ignored.
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed child so no zombie process is left behind.
            proc.wait()