"""Tests for ActionMaskComputer -- pure logic, no env dependency.""" import numpy as np import pytest from types import SimpleNamespace from rl.action_mask import ActionMaskComputer from rl.feature_builder import ACTION_DECODE_TABLE, N_ACTIONS from app.models import ServiceType def _make_obs( escalation_budget=5, missing_doc_counts=None, urgent_counts=None, reserve_officers=3, allocations=None, active_cases_by_service=None, ): services = [s for s in ServiceType] missing_doc_counts = missing_doc_counts or {} urgent_counts = urgent_counts or {} active_cases_by_service = active_cases_by_service or {svc.value: 10 for svc in services} allocations = allocations or {svc: 1 for svc in services} snapshots = { svc.value: SimpleNamespace( service_type=svc, total_pending=active_cases_by_service.get(svc.value, 0), avg_waiting_days=3.0, urgent_pending=urgent_counts.get(svc.value, 2), blocked_missing_docs=missing_doc_counts.get(svc.value, 0), escalated_cases=0, public_stage_counts={}, ) for svc in services } return SimpleNamespace( queue_snapshots=snapshots, escalation_budget_remaining=escalation_budget, officer_pool=SimpleNamespace( total_officers=lambda: 10, allocated=allocations, idle_officers=reserve_officers, ), day=5, max_days=30, total_backlog=50, total_completed=20, total_sla_breaches=3, fairness_gap=0.1, last_action_valid=True, last_action_message="ok", ) @pytest.fixture def amc(): return ActionMaskComputer() def test_advance_time_always_valid(amc): assert amc.compute(_make_obs(), "balanced")[18] def test_escalate_blocked_when_budget_zero(amc): mask = amc.compute(_make_obs(escalation_budget=0, urgent_counts={"passport": 5}), "balanced") for idx, (t, _, _, _) in ACTION_DECODE_TABLE.items(): if t == "escalate_service": assert not mask[idx] def test_missing_docs_blocked_when_no_pending(amc): mask = amc.compute(_make_obs(missing_doc_counts={}), "balanced") for idx, (t, _, _, _) in ACTION_DECODE_TABLE.items(): if t == "request_missing_documents": assert not mask[idx] def test_missing_docs_valid_when_pending(amc): first_svc = list(ServiceType)[0].value mask = amc.compute(_make_obs(missing_doc_counts={first_svc: 3}), "balanced") for idx, (t, s, _, _) in ACTION_DECODE_TABLE.items(): if t == "request_missing_documents" and s == first_svc: assert mask[idx] def test_reallocate_blocked_when_source_has_no_alloc(amc): zero_alloc = {svc: 0 for svc in ServiceType} mask = amc.compute(_make_obs(allocations=zero_alloc), "balanced") for idx, (t, _, _, _) in ACTION_DECODE_TABLE.items(): if t == "reallocate_officers": assert not mask[idx] def test_assign_capacity_blocked_when_no_reserve(amc): mask = amc.compute(_make_obs(reserve_officers=0), "balanced") for idx, (t, _, _, _) in ACTION_DECODE_TABLE.items(): if t == "assign_capacity": assert not mask[idx] def test_reallocate_blocked_when_only_one_active_service(amc): first = list(ServiceType)[0].value active_cases = {svc.value: 0 for svc in ServiceType} active_cases[first] = 10 mask = amc.compute(_make_obs(active_cases_by_service=active_cases), "balanced") for idx, (t, _, _, _) in ACTION_DECODE_TABLE.items(): if t == "reallocate_officers": assert not mask[idx] def test_redundant_priority_mode_blocked(amc): mask = amc.compute(_make_obs(), current_priority_mode="urgent_first") assert not mask[0] def test_mask_length(amc): assert len(amc.compute(_make_obs(), "balanced")) == N_ACTIONS def test_at_least_one_valid_action(amc): assert amc.compute(_make_obs(), "balanced").any() def test_only_advance_time_when_backlog_zero(amc): obs = _make_obs(active_cases_by_service={svc.value: 0 for svc in ServiceType}) obs.total_backlog = 0 mask = amc.compute(obs, "balanced") assert mask[18] assert int(mask.sum()) == 1