"""Integration tests for the OpenEnv-compliant ChakravyuhOpenEnv wrapper.

Exercise the full reset/step/state contract without going through HTTP —
this verifies the Environment subclass obeys the OpenEnv interface
(reset returns ObsT, step returns ObsT with done/reward, state is
introspectable) and that 2-decision episodes produce coherent rewards.
"""
from __future__ import annotations
import pytest
from chakravyuh_env import (
ChakravyuhAction,
ChakravyuhObservation,
ChakravyuhOpenEnv,
ChakravyuhState,
)
from chakravyuh_env.openenv_models import ChakravyuhAction as _A
from chakravyuh_env.schemas import VictimProfile
# ---------------------------------------------------------------------------
# Basic contract: reset/step/state shapes
# ---------------------------------------------------------------------------
@pytest.mark.unit
def test_reset_returns_observation() -> None:
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    assert isinstance(obs, ChakravyuhObservation)
    # Gym/OpenEnv convention: a freshly-reset episode is never terminal,
    # so the agent is guaranteed at least one step.
    assert obs.done is False
    assert obs.decision_index == 0
    # reset() plays scammer(1) + victim(2), leaving the turn counter at 2.
    assert obs.turn == 2
    participants = {message["sender"] for message in obs.chat_history}
    assert participants == {"scammer", "victim"}
@pytest.mark.unit
def test_state_property_before_and_after_reset() -> None:
    env = ChakravyuhOpenEnv()
    # Before any reset the state object exists but carries no episode.
    pre_reset = env.state
    assert isinstance(pre_reset, ChakravyuhState)
    assert pre_reset.episode_id is None
    env.reset(seed=7)
    # After reset the episode metadata must be populated.
    post_reset = env.state
    assert post_reset.episode_id is not None
    assert post_reset.scam_category is not None
    assert post_reset.victim_profile == VictimProfile.SEMI_URBAN.value
@pytest.mark.unit
def test_step_requires_prior_reset() -> None:
    # Stepping a never-reset env must fail loudly rather than guess state.
    env = ChakravyuhOpenEnv()
    with pytest.raises(RuntimeError, match="reset"):
        env.step(_A(score=0.5))
# ---------------------------------------------------------------------------
# Episode progression: 2-decision flow
# ---------------------------------------------------------------------------
def _run_episode(
    env: ChakravyuhOpenEnv,
    seed: int,
    score1: float = 0.9,
    score2: float = 0.9,
) -> ChakravyuhObservation:
    """Drive *env* through up to two analyzer decisions; return the last obs."""
    obs = env.reset(seed=seed)
    assert obs.done is False, "reset() must return a non-terminal observation"
    obs = env.step(_A(score=score1, signals=["urgency"], explanation="t1"))
    # Only take the second decision if the first didn't already end things.
    if not obs.done:
        obs = env.step(_A(score=score2, signals=["impersonation"], explanation="t2"))
    return obs
@pytest.mark.integration
@pytest.mark.parametrize("seed", [1, 42, 99, 256, 1000])
def test_episode_eventually_terminates(seed: int) -> None:
    env = ChakravyuhOpenEnv()
    terminal = _run_episode(env, seed=seed)
    assert terminal.done is True
    assert terminal.reward is not None
    assert isinstance(terminal.reward, float)
    # Analyzer reward lives roughly in [-0.8, +1.4]; assert a loose envelope.
    assert -5.0 < terminal.reward < 5.0
@pytest.mark.integration
def test_high_suspicion_score_flags_scam() -> None:
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    if obs.done:
        pytest.skip("Episode ended at reset for this seed")
    # Flag aggressively at both decision points.
    env.step(_A(score=0.99, flag_threshold=0.5, signals=["urgency"], explanation="flag"))
    final = None
    if not env.state.done:
        final = env.step(
            _A(score=0.99, flag_threshold=0.5, signals=["impersonation"], explanation="flag")
        )
    state = env.state
    assert state.analyzer_flagged is True
    assert state.done is True
    # If a second decision actually happened, its observation must agree.
    if final is not None:
        assert final.done is True
        assert final.outcome is not None
        assert final.outcome.get("analyzer_flagged") is True
@pytest.mark.integration
def test_low_suspicion_score_does_not_flag() -> None:
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    if obs.done:
        pytest.skip("Episode ended at reset for this seed")
    # Deliberately under-score at both decision points — no flag should fire.
    env.step(_A(score=0.10, flag_threshold=0.5, explanation="ignore"))
    if not env.state.done:
        env.step(_A(score=0.10, flag_threshold=0.5, explanation="ignore"))
    assert env.state.analyzer_flagged is False
@pytest.mark.integration
def test_step_after_done_raises() -> None:
    env = ChakravyuhOpenEnv()
    terminal = _run_episode(env, seed=3)
    assert terminal.done is True
    # Stepping past the terminal state must be rejected explicitly.
    with pytest.raises(RuntimeError, match="already done"):
        env.step(_A(score=0.5))
# ---------------------------------------------------------------------------
# Reward & observation payload sanity
# ---------------------------------------------------------------------------
@pytest.mark.integration
def test_terminal_observation_includes_reward_breakdown() -> None:
    env = ChakravyuhOpenEnv()
    final = _run_episode(env, seed=11, score1=0.9, score2=0.9)
    assert final.done is True
    assert final.reward_breakdown is not None
    # Composable-rubric breakdown: one entry per child rubric, plus the
    # weighted total, the weight map used, and a legacy-analyzer reference
    # value kept for comparison.
    expected_breakdown_keys = (
        "total",
        "detection",
        "missed_scam",
        "false_positive",
        "calibration",
        "explanation",
        "weights",
        "legacy_analyzer",
    )
    for key in expected_breakdown_keys:
        assert key in final.reward_breakdown, f"reward_breakdown missing {key}"
    assert final.outcome is not None
    for key in ("money_extracted", "analyzer_flagged", "detected_by_turn"):
        assert key in final.outcome
@pytest.mark.integration
@pytest.mark.parametrize("seed", range(50))
def test_reset_is_always_non_terminal_across_seeds(seed: int) -> None:
    """Regression: reset() must never return done=True for any seed.

    Previously a YOUNG_URBAN victim with low trust + turn-1 info request
    could VictimRefuse on turn 2, returning a 0-step terminal episode
    from reset() — violating Gym/OpenEnv convention and breaking training
    loops that expect at least one step per trajectory.
    """
    # The refuse path depends on base trust level, so sweep every profile.
    for profile in VictimProfile:
        env = ChakravyuhOpenEnv(victim_profile=profile)
        obs = env.reset(seed=seed)
        assert obs.done is False, (
            f"reset() returned done=True on seed={seed}, profile={profile.value} "
            f"(chat_history={[m['text'][:50] for m in obs.chat_history]})"
        )
        assert obs.decision_index == 0
@pytest.mark.integration
def test_regulator_state_does_not_leak_across_resets() -> None:
    """Regression: the regulator's per-episode outcome buffer must be
    cleared by reset(). Previously it was created in __init__ once, so
    outcomes accumulated across resets within a single env instance —
    polluting downstream rule-update logic under WebSocket session reuse.
    """
    env = ChakravyuhOpenEnv()
    # First episode: run to completion so exactly one outcome is logged.
    env.reset(seed=1)
    env.step(_A(score=0.9))
    if not env.state.done:
        env.step(_A(score=0.9))
    first_regulator = env._regulator
    assert first_regulator is not None
    assert len(first_regulator._outcome_buffer) == 1
    # Second episode: reset() must swap in a brand-new regulator object.
    env.reset(seed=2)
    second_regulator = env._regulator
    assert second_regulator is not None
    assert second_regulator is not first_regulator, (
        "reset() must create a new regulator, not reuse the previous one"
    )
    assert len(second_regulator._outcome_buffer) == 0
@pytest.mark.integration
def test_concurrent_instances_have_isolated_regulators() -> None:
    """Two env instances created from the same factory must not share
    regulator state — required for SUPPORTS_CONCURRENT_SESSIONS=True."""
    env_a = ChakravyuhOpenEnv()
    env_b = ChakravyuhOpenEnv()
    env_a.reset(seed=1)
    env_b.reset(seed=2)
    # Drive env_a to completion while env_b receives no step() calls.
    env_a.step(_A(score=0.9))
    if not env_a.state.done:
        env_a.step(_A(score=0.9))
    # env_b's regulator must stay empty even though env_a logged an outcome.
    assert env_a._regulator is not env_b._regulator
    assert len(env_a._regulator._outcome_buffer) == 1
    assert len(env_b._regulator._outcome_buffer) == 0
@pytest.mark.integration
def test_invalid_signal_name_raises() -> None:
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    if obs.done:
        pytest.skip("Episode ended at reset for this seed")
    # Unknown signal names must be rejected, not silently dropped.
    with pytest.raises(ValueError, match="AnalyzerSignal"):
        env.step(_A(score=0.9, signals=["not_a_real_signal"]))
@pytest.mark.integration
def test_determinism_same_seed_same_transcript() -> None:
    # Identical seeds on independent instances must yield identical chats.
    obs_a = ChakravyuhOpenEnv().reset(seed=1234)
    obs_b = ChakravyuhOpenEnv().reset(seed=1234)
    transcript_a = [turn["text"] for turn in obs_a.chat_history]
    transcript_b = [turn["text"] for turn in obs_b.chat_history]
    assert transcript_a == transcript_b
@pytest.mark.unit
def test_observation_round_trips_through_json() -> None:
    """Observation must survive a JSON serialise / deserialise round-trip
    so OpenEnv's wire transport is lossless."""
    env = ChakravyuhOpenEnv()
    original = env.reset(seed=42)
    rehydrated = ChakravyuhObservation.model_validate_json(
        original.model_dump_json()
    )
    # Scalar fields must come back unchanged.
    for field in (
        "turn",
        "decision_index",
        "episode_id",
        "scam_category",
        "victim_profile",
        "schema_version",
    ):
        assert getattr(rehydrated, field) == getattr(original, field)
    assert len(rehydrated.chat_history) == len(original.chat_history)
@pytest.mark.unit
def test_observation_carries_schema_version() -> None:
    """Every observation must carry the OpenEnv schema_version so old
    training runs can detect wire-format mismatch on replay."""
    from chakravyuh_env.openenv_models import CHAKRAVYUH_SCHEMA_VERSION
    obs = ChakravyuhOpenEnv().reset(seed=42)
    assert obs.schema_version == CHAKRAVYUH_SCHEMA_VERSION
    # Pin the literal too, so a version bump is a deliberate test change.
    assert obs.schema_version == "0.2.0"
@pytest.mark.unit
def test_chat_turn_validator_documents_wire_shape() -> None:
    """The ChatTurn validator should accept any chat_history dict the env
    actually emits — guards against drift between wire shape and docs."""
    from chakravyuh_env.openenv_models import ChatTurn
    env = ChakravyuhOpenEnv()
    obs = env.reset(seed=42)
    known_senders = {
        "scammer", "victim", "analyzer", "bank_monitor", "regulator"
    }
    for raw_turn in obs.chat_history:
        parsed = ChatTurn.model_validate(raw_turn)
        assert parsed.sender in known_senders
        assert parsed.turn >= 0
@pytest.mark.unit
def test_reward_breakdown_validator_matches_terminal_obs_shape() -> None:
    """Terminal observations carry a reward_breakdown dict — its shape
    must match the RewardBreakdown documented schema.

    The walk-to-terminal loop is bounded: the original unbounded
    ``while not obs.done`` would hang pytest forever if a regression ever
    prevented the episode from terminating. Now it fails fast instead.
    """
    from chakravyuh_env.openenv_models import RewardBreakdown
    env = ChakravyuhOpenEnv()
    env.reset(seed=42)
    obs = env.step(_A(score=0.95, signals=["urgency"], explanation="OTP request"))
    # Walk through to terminal if not yet done (episodes are 2 decisions,
    # so 20 is a generous safety margin).
    for _ in range(20):
        if obs.done:
            break
        obs = env.step(_A(score=0.95, signals=["urgency"], explanation="next"))
    assert obs.done, "episode failed to terminate within 20 steps"
    if obs.reward_breakdown is not None:
        # extra="ignore" lets unknown keys (e.g. composite_unweighted) pass.
        validated = RewardBreakdown.model_validate(obs.reward_breakdown)
        assert isinstance(validated.composite, float)
# ---------------------------------------------------------------------------
# Server factory smoke test β ensures create_app wiring doesn't explode.
# Runs as a pure import/invocation test; no HTTP server is actually bound.
# ---------------------------------------------------------------------------
@pytest.mark.integration
def test_fastapi_app_builds() -> None:
    from server.app import app
    assert app is not None
    # ``/health`` and ``/schema`` are part of the stock OpenEnv HTTP surface.
    mounted_paths = {getattr(route, "path", None) for route in app.routes}
    for endpoint in ("/health", "/schema", "/reset", "/step", "/state"):
        assert endpoint in mounted_paths
@pytest.mark.integration
def test_websocket_full_episode_round_trip() -> None:
    """Exercise a real HTTP server + WebSocket client — the path judges use.

    Starts uvicorn as a subprocess on a free port, connects via
    ``ChakravyuhEnvClient`` (OpenEnv's standard WS client), runs a 2-step
    episode, and asserts terminal reward + outcome fields arrive intact.
    """
    import signal
    import socket
    import subprocess
    import sys
    import time
    from pathlib import Path
    repo_root = Path(__file__).resolve().parent.parent
    # Pick a free port by binding to 0 and releasing.
    # NOTE(review): there is a small race between closing this socket and
    # uvicorn binding the port — acceptable for a test, but not airtight.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    # Launch the real server out-of-process so the WS transport is exercised
    # end-to-end rather than via an in-process test client.
    proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "server.app:app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
            "--log-level",
            "error",
        ],
        cwd=str(repo_root),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    try:
        # Poll /health until the server is ready (up to 10 s).
        import urllib.request
        deadline = time.time() + 10
        url = f"http://127.0.0.1:{port}/health"
        while time.time() < deadline:
            try:
                urllib.request.urlopen(url, timeout=0.5).read()
                break
            except Exception:
                time.sleep(0.3)
        else:
            # while/else: the loop ran out of time without ever breaking.
            pytest.skip("Server failed to start in 10 s")
        from chakravyuh_env.openenv_client import ChakravyuhEnvClient
        with ChakravyuhEnvClient(base_url=f"http://127.0.0.1:{port}").sync() as env:
            r = env.reset(seed=42)
            assert not r.done
            assert r.observation.turn >= 2
            assert len(r.observation.chat_history) >= 2
            # Two aggressive decisions should flag the scam and end the episode.
            r = env.step(_A(score=0.9, signals=["urgency"]))
            if not r.done:
                r = env.step(_A(score=0.9, signals=["impersonation"]))
            assert r.done is True
            assert r.reward is not None
            assert r.observation.outcome is not None
            assert r.observation.reward_breakdown is not None
            assert r.observation.outcome.get("analyzer_flagged") is True
    finally:
        # Graceful shutdown first; escalate to SIGKILL if uvicorn hangs.
        proc.send_signal(signal.SIGTERM)
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
|