# claims-env/tests/test_environment.py
# akhiilll's picture
# Deploy ClaimSense adjudication gym
# 1cfeb15 verified
"""pytest suite for the ClaimSense adjudication gym.

Run with::

    pytest tests/ -v

Imports use the legacy ``ClaimsAction``/``ClaimsEnvironment`` names to
exercise the backwards-compatibility aliases as well as the underlying
implementation.
"""
from __future__ import annotations
import pytest
from claims_env.models import ClaimsAction, ClaimsObservation
from claims_env.server.claims_environment import (
AdjudicationGym,
ClaimsEnvironment,
)
from claims_env.server.mock_systems import (
CLAIM_SCENARIOS,
MockFraudAPI,
MockPolicyDB,
get_scenario_by_index,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def simple_env() -> ClaimsEnvironment:
    """Provide an already-reset gym bound to scenario 0 (clean approval)."""
    gym = ClaimsEnvironment(scenario_index=0)
    # Reset before handing the gym over; the returned observation is not needed here.
    gym.reset()
    return gym
@pytest.fixture
def fraud_env() -> ClaimsEnvironment:
    """Provide an already-reset gym bound to scenario 2 (staged-accident fraud)."""
    gym = ClaimsEnvironment(scenario_index=2)
    # Reset before handing the gym over; the returned observation is not needed here.
    gym.reset()
    return gym
# ---------------------------------------------------------------------------
# Reset behaviour
# ---------------------------------------------------------------------------
class TestResetSurface:
    """Checks on the legacy alias and on what ``reset()`` leaves behind."""

    def test_alias_resolves_to_implementation(self) -> None:
        """The legacy name must be the very same object as ``AdjudicationGym``."""
        assert ClaimsEnvironment is AdjudicationGym

    def test_reset_returns_observation_shape(self) -> None:
        """``reset()`` returns a populated, non-terminal ``ClaimsObservation``."""
        gym = ClaimsEnvironment(scenario_index=0)
        first_obs = gym.reset()

        assert isinstance(first_obs, ClaimsObservation)
        # Identity fields must both be truthy.
        assert first_obs.claim_id
        assert first_obs.claim_type
        assert first_obs.claim_amount_requested > 0
        assert first_obs.is_terminal is False
        assert first_obs.available_actions, "available_actions should be populated"

    def test_reset_seeds_episode_meta(self, simple_env: ClaimsEnvironment) -> None:
        """After a reset the episode counters and reward all start at zero."""
        episode = simple_env.state
        assert episode.actions_taken == 0
        assert episode.queries_made == 0
        assert episode.total_reward == 0.0
# ---------------------------------------------------------------------------
# Information-gathering verbs
# ---------------------------------------------------------------------------
class TestQueryActions:
    """Checks on the ``query_policy`` and ``check_fraud`` actions."""

    def test_query_policy_marks_state(self, simple_env: ClaimsEnvironment) -> None:
        """``query_policy`` succeeds, flags the state, and mentions policy/coverage."""
        action = ClaimsAction(action_type="query_policy")
        obs = simple_env.step(action)

        assert obs.action_success
        assert simple_env.state.policy_queried
        reply = obs.system_response.lower()
        assert any(keyword in reply for keyword in ("policy", "coverage"))

    def test_check_fraud_returns_signals(self, fraud_env: ClaimsEnvironment) -> None:
        """``check_fraud`` succeeds, flags the state, and mentions risk/fraud."""
        action = ClaimsAction(action_type="check_fraud")
        obs = fraud_env.step(action)

        assert obs.action_success
        assert fraud_env.state.fraud_checked
        reply = obs.system_response.lower()
        assert any(keyword in reply for keyword in ("risk", "fraud"))

    def test_query_steps_increment_counters(
        self, simple_env: ClaimsEnvironment
    ) -> None:
        """Two query steps bump both ``actions_taken`` and ``queries_made`` to 2."""
        for verb in ("query_policy", "check_fraud"):
            simple_env.step(ClaimsAction(action_type=verb))

        assert simple_env.state.actions_taken == 2
        assert simple_env.state.queries_made == 2
# ---------------------------------------------------------------------------
# Terminal verbs and reward shaping
# ---------------------------------------------------------------------------
class TestTerminalVerbs:
    """Checks on ``approve`` / ``deny`` / ``escalate`` and the reward fields they set."""

    @pytest.mark.parametrize(
        ("verb", "reason_substring"),
        [
            ("approve", "approved"),
            ("deny", "denied"),
            ("escalate", "escalat"),
        ],
    )
    def test_terminals_short_circuit(
        self, simple_env: ClaimsEnvironment, verb: str, reason_substring: str
    ) -> None:
        """Each terminal verb ends the episode with a matching ``terminal_reason``."""
        # ``approve`` is sent with a payout; the other verbs are sent with a reason.
        if verb == "approve":
            params = {"payout": 3000.0}
        else:
            params = {"reason": "test"}

        action = ClaimsAction(action_type=verb, parameters=params)
        obs = simple_env.step(action)

        assert obs.is_terminal
        assert reason_substring in obs.terminal_reason.lower()

    def test_correct_approval_yields_positive_reward(
        self, simple_env: ClaimsEnvironment
    ) -> None:
        """Querying the policy and then approving yields positive rewards."""
        lookup = ClaimsAction(action_type="query_policy")
        approval = ClaimsAction(action_type="approve", parameters={"payout": 3000.0})
        simple_env.step(lookup)
        simple_env.step(approval)

        assert simple_env.state.total_reward > 0
        assert simple_env.state.correctness_reward > 0

    def test_catching_fraud_grants_bonus(
        self, fraud_env: ClaimsEnvironment
    ) -> None:
        """Denying the fraud scenario gives a positive ``fraud_detection_reward``."""
        denial = ClaimsAction(action_type="deny", parameters={"reason": "fraud"})
        fraud_env.step(denial)

        assert fraud_env.state.fraud_detection_reward > 0

    def test_missing_fraud_incurs_penalty(
        self, fraud_env: ClaimsEnvironment
    ) -> None:
        """Approving the fraud scenario gives a negative ``fraud_detection_reward``."""
        approval = ClaimsAction(action_type="approve", parameters={"payout": 12000.0})
        fraud_env.step(approval)

        assert fraud_env.state.fraud_detection_reward < 0
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
def test_unknown_action_returns_error_observation(
    simple_env: ClaimsEnvironment,
) -> None:
    """An unrecognised verb yields a failed observation whose response says "error"."""
    bogus_action = ClaimsAction(action_type="not_a_real_verb")
    result = simple_env.step(bogus_action)

    assert result.action_success is False
    reply = result.system_response.lower()
    assert "error" in reply
# ---------------------------------------------------------------------------
# Backend stubs
# ---------------------------------------------------------------------------
class TestBackendStubs:
    """Direct checks on the mock policy and fraud backends."""

    def test_policy_lookup_returns_expected_keys(self) -> None:
        """``lookup_policy()`` for scenario 0 exposes the four expected keys."""
        scenario = get_scenario_by_index(0)
        policy = MockPolicyDB(scenario).lookup_policy()

        expected_keys = ("policy_id", "policy_status", "coverage_limit", "deductible")
        absent = [key for key in expected_keys if key not in policy]
        assert not absent

    def test_fraud_signal_score_is_bounded(self) -> None:
        """``check_fraud_signals()`` for scenario 2 returns a score in [0, 1] plus extras."""
        scenario = get_scenario_by_index(2)
        signals = MockFraudAPI(scenario).check_fraud_signals()

        score = signals["risk_score"]
        assert 0.0 <= score <= 1.0
        assert "flags" in signals
        assert "recommendation" in signals
# ---------------------------------------------------------------------------
# Library coverage
# ---------------------------------------------------------------------------
class TestCaseLibrary:
    """Sanity checks over the whole ``CLAIM_SCENARIOS`` collection."""

    def test_each_case_loads(self) -> None:
        """Every scenario index resets to an observation carrying that case's claim id."""
        for index, expected in enumerate(CLAIM_SCENARIOS):
            first_obs = ClaimsEnvironment(scenario_index=index).reset()
            assert first_obs.claim_id == expected.claim_id

    def test_library_spans_required_verdicts(self) -> None:
        """The library contains approve, deny and partial_approve verdicts."""
        required = {"approve", "deny", "partial_approve"}
        present = {scenario.true_verdict for scenario in CLAIM_SCENARIOS}
        assert required.issubset(present)

    def test_library_has_fraud_examples(self) -> None:
        """At least two scenarios are flagged ``is_fraud``."""
        fraudulent = [scenario for scenario in CLAIM_SCENARIOS if scenario.is_fraud]
        assert len(fraudulent) >= 2