"""Tests for ActionMaskComputer -- pure logic, no env dependency."""
import numpy as np
import pytest
from types import SimpleNamespace
from rl.action_mask import ActionMaskComputer
from rl.feature_builder import ACTION_DECODE_TABLE, N_ACTIONS
from app.models import ServiceType
def _make_obs(
    escalation_budget=5,
    missing_doc_counts=None,
    urgent_counts=None,
    reserve_officers=3,
    allocations=None,
    active_cases_by_service=None,
):
    """Build a minimal duck-typed observation for ActionMaskComputer tests.

    ``SimpleNamespace`` objects stand in for the real observation and queue
    snapshot models. Only the attributes set below exist.

    Args:
        escalation_budget: Value for ``escalation_budget_remaining``.
        missing_doc_counts: Maps service value (str) to that snapshot's
            ``blocked_missing_docs``. Services not listed get 0.
        urgent_counts: Maps service value (str) to that snapshot's
            ``urgent_pending``. Services not listed get 2.
        reserve_officers: Value for ``officer_pool.idle_officers``.
        allocations: Maps ``ServiceType`` member to allocated officer count,
            stored as ``officer_pool.allocated``. ``None`` means one officer
            per service.
        active_cases_by_service: Maps service value (str) to that snapshot's
            ``total_pending``. ``None`` means 10 for every service. Services
            missing from an explicitly passed mapping get 0.

    Returns:
        A ``SimpleNamespace`` with ``queue_snapshots`` (keyed by service
        value), an ``officer_pool`` namespace, and fixed scalar episode
        fields (``day=5``, ``total_backlog=50``, and so on).
    """
    services = list(ServiceType)
    # Test against None rather than truthiness, so an explicitly passed empty
    # mapping is honoured instead of being swapped for the default.
    if missing_doc_counts is None:
        missing_doc_counts = {}
    if urgent_counts is None:
        urgent_counts = {}
    if active_cases_by_service is None:
        active_cases_by_service = {svc.value: 10 for svc in services}
    if allocations is None:
        # Keyed by the enum member itself, unlike the snapshot dicts, which
        # are keyed by ``svc.value``.
        allocations = {svc: 1 for svc in services}
    snapshots = {
        svc.value: SimpleNamespace(
            service_type=svc,
            total_pending=active_cases_by_service.get(svc.value, 0),
            avg_waiting_days=3.0,
            urgent_pending=urgent_counts.get(svc.value, 2),
            blocked_missing_docs=missing_doc_counts.get(svc.value, 0),
            escalated_cases=0,
            public_stage_counts={},
        )
        for svc in services
    }
    return SimpleNamespace(
        queue_snapshots=snapshots,
        escalation_budget_remaining=escalation_budget,
        officer_pool=SimpleNamespace(
            # Exposed as a callable, not an int, mirroring the original fixture.
            total_officers=lambda: 10,
            allocated=allocations,
            idle_officers=reserve_officers,
        ),
        day=5, max_days=30, total_backlog=50, total_completed=20,
        total_sla_breaches=3, fairness_gap=0.1,
        last_action_valid=True, last_action_message="ok",
    )
@pytest.fixture
def amc():
    """Provide a freshly constructed ActionMaskComputer to each test."""
    computer = ActionMaskComputer()
    return computer
def test_advance_time_always_valid(amc):
    """The advance-time action is valid for the default observation."""
    # NOTE(review): 18 is assumed to be the advance_time slot in
    # ACTION_DECODE_TABLE; confirm against rl.feature_builder.
    advance_idx = 18
    observation = _make_obs()
    mask = amc.compute(observation, "balanced")
    assert mask[advance_idx]
def test_escalate_blocked_when_budget_zero(amc):
    """With no escalation budget left, every escalate_service action is masked."""
    mask = amc.compute(
        _make_obs(escalation_budget=0, urgent_counts={"passport": 5}),
        "balanced",
    )
    escalate_indices = [
        idx
        for idx, (action_type, _, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "escalate_service"
    ]
    # Guard against a vacuous pass: if no entry matches (e.g. the action-type
    # string was renamed), the loop below would assert nothing.
    assert escalate_indices, "no 'escalate_service' actions in ACTION_DECODE_TABLE"
    for idx in escalate_indices:
        assert not mask[idx], f"escalate action {idx} should be masked"
def test_missing_docs_blocked_when_no_pending(amc):
    """With no cases blocked on documents, request_missing_documents is masked."""
    mask = amc.compute(_make_obs(missing_doc_counts={}), "balanced")
    request_indices = [
        idx
        for idx, (action_type, _, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "request_missing_documents"
    ]
    # Guard against a vacuous pass when no table entry carries this type.
    assert request_indices, (
        "no 'request_missing_documents' actions in ACTION_DECODE_TABLE"
    )
    for idx in request_indices:
        assert not mask[idx], f"request-docs action {idx} should be masked"
def test_missing_docs_valid_when_pending(amc):
    """request_missing_documents is valid for a service with blocked cases."""
    first_svc = list(ServiceType)[0].value
    mask = amc.compute(_make_obs(missing_doc_counts={first_svc: 3}), "balanced")
    request_indices = [
        idx
        for idx, (action_type, service, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "request_missing_documents" and service == first_svc
    ]
    # Guard against a vacuous pass. This fails if the type string changes, or
    # if the decode table stores enum members instead of ``.value`` strings,
    # because then nothing compares equal to ``first_svc``.
    assert request_indices, (
        f"no 'request_missing_documents' action for {first_svc!r} "
        "in ACTION_DECODE_TABLE"
    )
    for idx in request_indices:
        assert mask[idx], f"request-docs action {idx} should be valid"
def test_reallocate_blocked_when_source_has_no_alloc(amc):
    """With zero officers allocated anywhere, no reallocation is valid."""
    zero_alloc = {svc: 0 for svc in ServiceType}
    mask = amc.compute(_make_obs(allocations=zero_alloc), "balanced")
    reallocate_indices = [
        idx
        for idx, (action_type, _, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "reallocate_officers"
    ]
    # Guard against a vacuous pass when no table entry carries this type.
    assert reallocate_indices, (
        "no 'reallocate_officers' actions in ACTION_DECODE_TABLE"
    )
    for idx in reallocate_indices:
        assert not mask[idx], f"reallocate action {idx} should be masked"
def test_assign_capacity_blocked_when_no_reserve(amc):
    """With no idle officers in reserve, every assign_capacity action is masked."""
    mask = amc.compute(_make_obs(reserve_officers=0), "balanced")
    assign_indices = [
        idx
        for idx, (action_type, _, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "assign_capacity"
    ]
    # Guard against a vacuous pass when no table entry carries this type.
    assert assign_indices, "no 'assign_capacity' actions in ACTION_DECODE_TABLE"
    for idx in assign_indices:
        assert not mask[idx], f"assign-capacity action {idx} should be masked"
def test_reallocate_blocked_when_only_one_active_service(amc):
    """When only one service has pending cases, reallocation is masked."""
    first = list(ServiceType)[0].value
    active_cases = {svc.value: 0 for svc in ServiceType}
    active_cases[first] = 10
    mask = amc.compute(_make_obs(active_cases_by_service=active_cases), "balanced")
    reallocate_indices = [
        idx
        for idx, (action_type, _, _, _) in ACTION_DECODE_TABLE.items()
        if action_type == "reallocate_officers"
    ]
    # Guard against a vacuous pass when no table entry carries this type.
    assert reallocate_indices, (
        "no 'reallocate_officers' actions in ACTION_DECODE_TABLE"
    )
    for idx in reallocate_indices:
        assert not mask[idx], f"reallocate action {idx} should be masked"
def test_redundant_priority_mode_blocked(amc):
    """Re-selecting the currently active priority mode is masked out."""
    # NOTE(review): index 0 is assumed to decode to the "urgent_first"
    # priority-mode action; confirm against ACTION_DECODE_TABLE.
    observation = _make_obs()
    mask = amc.compute(observation, current_priority_mode="urgent_first")
    assert not mask[0]
def test_mask_length(amc):
    """The mask has exactly one entry per action in the action space."""
    mask = amc.compute(_make_obs(), "balanced")
    assert N_ACTIONS == len(mask)
def test_at_least_one_valid_action(amc):
    """The default observation never yields an all-masked action space."""
    mask = amc.compute(_make_obs(), "balanced")
    assert mask.any()
def test_only_advance_time_when_backlog_zero(amc):
    """With nothing pending anywhere, advancing time is the only valid action."""
    empty_queues = {svc.value: 0 for svc in ServiceType}
    observation = _make_obs(active_cases_by_service=empty_queues)
    # _make_obs hard-codes total_backlog=50, so override it to match the
    # empty queues.
    observation.total_backlog = 0
    mask = amc.compute(observation, "balanced")
    # NOTE(review): index 18 is assumed to be the advance_time slot.
    assert mask[18]
    valid_count = int(mask.sum())
    assert valid_count == 1
|