"""pytest suite for the ClaimSense adjudication gym.

Run with::

    pytest tests/ -v

The legacy ``ClaimsAction``/``ClaimsEnvironment`` names are imported on
purpose: using them exercises the backwards-compatibility aliases as well as
the underlying implementation.
"""

from __future__ import annotations

import pytest

from claims_env.models import ClaimsAction, ClaimsObservation
from claims_env.server.claims_environment import (
    AdjudicationGym,
    ClaimsEnvironment,
)
from claims_env.server.mock_systems import (
    CLAIM_SCENARIOS,
    MockFraudAPI,
    MockPolicyDB,
    get_scenario_by_index,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def simple_env() -> ClaimsEnvironment:
    """Return an already-reset gym pinned to scenario 0 (clean approval)."""
    gym = ClaimsEnvironment(scenario_index=0)
    gym.reset()
    return gym


@pytest.fixture
def fraud_env() -> ClaimsEnvironment:
    """Return an already-reset gym pinned to scenario 2 (staged-accident fraud)."""
    gym = ClaimsEnvironment(scenario_index=2)
    gym.reset()
    return gym


# ---------------------------------------------------------------------------
# Reset behaviour
# ---------------------------------------------------------------------------


class TestResetSurface:
    """Checks on what ``reset()`` returns and the state it leaves behind."""

    def test_alias_resolves_to_implementation(self) -> None:
        """The legacy class name must be the very same object as the new one."""
        assert ClaimsEnvironment is AdjudicationGym

    def test_reset_returns_observation_shape(self) -> None:
        """A fresh reset yields a populated, non-terminal observation."""
        gym = ClaimsEnvironment(scenario_index=0)
        first_obs = gym.reset()

        assert isinstance(first_obs, ClaimsObservation)
        # Both identifying fields have to be truthy.
        assert first_obs.claim_id
        assert first_obs.claim_type
        assert first_obs.claim_amount_requested > 0
        assert first_obs.is_terminal is False
        assert first_obs.available_actions, "available_actions should be populated"

    def test_reset_seeds_episode_meta(self, simple_env: ClaimsEnvironment) -> None:
        """Episode counters and the reward total start from zero."""
        state = simple_env.state
        assert state.actions_taken == 0
        assert state.queries_made == 0
        assert state.total_reward == 0.0


# ---------------------------------------------------------------------------
# Information-gathering verbs
# ---------------------------------------------------------------------------


class TestQueryActions:
    """Checks on the non-terminal ``query_policy`` / ``check_fraud`` verbs."""

    def test_query_policy_marks_state(self, simple_env: ClaimsEnvironment) -> None:
        """Querying the policy succeeds, flags the state and mentions the policy."""
        result = simple_env.step(ClaimsAction(action_type="query_policy"))
        reply = result.system_response.lower()

        assert result.action_success
        assert simple_env.state.policy_queried
        assert any(word in reply for word in ("policy", "coverage"))

    def test_check_fraud_returns_signals(self, fraud_env: ClaimsEnvironment) -> None:
        """A fraud check succeeds, flags the state and mentions risk or fraud."""
        result = fraud_env.step(ClaimsAction(action_type="check_fraud"))
        reply = result.system_response.lower()

        assert result.action_success
        assert fraud_env.state.fraud_checked
        assert any(word in reply for word in ("risk", "fraud"))

    def test_query_steps_increment_counters(
        self, simple_env: ClaimsEnvironment
    ) -> None:
        """Two query verbs count as two actions and two queries."""
        for verb in ("query_policy", "check_fraud"):
            simple_env.step(ClaimsAction(action_type=verb))

        state = simple_env.state
        assert state.actions_taken == 2
        assert state.queries_made == 2


# ---------------------------------------------------------------------------
# Terminal verbs and reward shaping
# ---------------------------------------------------------------------------


class TestTerminalVerbs:
    """Checks on episode-ending verbs and the reward components they set."""

    @pytest.mark.parametrize(
        "verb,reason_substring",
        [
            ("approve", "approved"),
            ("deny", "denied"),
            ("escalate", "escalat"),
        ],
    )
    def test_terminals_short_circuit(
        self, simple_env: ClaimsEnvironment, verb: str, reason_substring: str
    ) -> None:
        """Each terminal verb ends the episode with a matching reason string."""
        # Only ``approve`` is given a payout; the other verbs get a reason.
        if verb == "approve":
            payload = {"payout": 3000.0}
        else:
            payload = {"reason": "test"}

        final_obs = simple_env.step(
            ClaimsAction(action_type=verb, parameters=payload)
        )

        assert final_obs.is_terminal
        assert reason_substring in final_obs.terminal_reason.lower()

    def test_correct_approval_yields_positive_reward(
        self, simple_env: ClaimsEnvironment
    ) -> None:
        """Approving scenario 0 after a policy query yields positive rewards."""
        query = ClaimsAction(action_type="query_policy")
        approval = ClaimsAction(action_type="approve", parameters={"payout": 3000.0})
        simple_env.step(query)
        simple_env.step(approval)

        state = simple_env.state
        assert state.total_reward > 0
        assert state.correctness_reward > 0

    def test_catching_fraud_grants_bonus(
        self, fraud_env: ClaimsEnvironment
    ) -> None:
        """Denying the fraud scenario gives a positive fraud-detection reward."""
        denial = ClaimsAction(action_type="deny", parameters={"reason": "fraud"})
        fraud_env.step(denial)

        assert fraud_env.state.fraud_detection_reward > 0

    def test_missing_fraud_incurs_penalty(
        self, fraud_env: ClaimsEnvironment
    ) -> None:
        """Approving the fraud scenario gives a negative fraud-detection reward."""
        approval = ClaimsAction(action_type="approve", parameters={"payout": 12000.0})
        fraud_env.step(approval)

        assert fraud_env.state.fraud_detection_reward < 0


# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------


def test_unknown_action_returns_error_observation(
    simple_env: ClaimsEnvironment,
) -> None:
    """An unrecognised verb produces a failed observation that mentions an error."""
    bogus = ClaimsAction(action_type="not_a_real_verb")
    result = simple_env.step(bogus)

    assert result.action_success is False
    assert "error" in result.system_response.lower()


# ---------------------------------------------------------------------------
# Backend stubs
# ---------------------------------------------------------------------------


class TestBackendStubs:
    """Checks on the mock policy database and mock fraud API used by the gym."""

    def test_policy_lookup_returns_expected_keys(self) -> None:
        """The policy lookup result carries the four expected keys."""
        scenario = get_scenario_by_index(0)
        record = MockPolicyDB(scenario).lookup_policy()

        expected = ("policy_id", "policy_status", "coverage_limit", "deductible")
        for field_name in expected:
            assert field_name in record

    def test_fraud_signal_score_is_bounded(self) -> None:
        """The fraud result has a risk score in [0, 1] plus flags and a recommendation."""
        scenario = get_scenario_by_index(2)
        signals = MockFraudAPI(scenario).check_fraud_signals()

        score = signals["risk_score"]
        assert 0.0 <= score <= 1.0
        assert "flags" in signals
        assert "recommendation" in signals


# ---------------------------------------------------------------------------
# Library coverage
# ---------------------------------------------------------------------------


class TestCaseLibrary:
    """Checks on the ``CLAIM_SCENARIOS`` library as a whole."""

    def test_each_case_loads(self) -> None:
        """Every scenario index resets to an observation with the matching claim id."""
        for position, scenario in enumerate(CLAIM_SCENARIOS):
            gym = ClaimsEnvironment(scenario_index=position)
            first_obs = gym.reset()
            assert first_obs.claim_id == scenario.claim_id

    def test_library_spans_required_verdicts(self) -> None:
        """Approve, deny and partial-approve verdicts all occur in the library."""
        required = {"approve", "deny", "partial_approve"}
        present = {scenario.true_verdict for scenario in CLAIM_SCENARIOS}
        assert required.issubset(present)

    def test_library_has_fraud_examples(self) -> None:
        """At least two scenarios are marked as fraud."""
        fraudulent = [scenario for scenario in CLAIM_SCENARIOS if scenario.is_fraud]
        assert len(fraudulent) >= 2