File size: 4,221 Bytes
df97e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Tests for ActionMaskComputer -- pure logic, no env dependency."""

import numpy as np
import pytest
from types import SimpleNamespace

from rl.action_mask import ActionMaskComputer
from rl.feature_builder import ACTION_DECODE_TABLE, N_ACTIONS
from app.models import ServiceType


def _make_obs(
    escalation_budget=5,
    missing_doc_counts=None,
    urgent_counts=None,
    reserve_officers=3,
    allocations=None,
    active_cases_by_service=None,
):
    """Build a lightweight stand-in observation for the mask tests.

    The result is a tree of ``SimpleNamespace`` objects rather than a real
    environment observation, so the tests need no running environment.

    Args:
        escalation_budget: Stored as ``escalation_budget_remaining``.
        missing_doc_counts: Maps service value -> ``blocked_missing_docs``.
            Services not listed get 0. ``None`` means no entries.
        urgent_counts: Maps service value -> ``urgent_pending``.
            Services not listed get 2. ``None`` means no entries.
        reserve_officers: Stored as ``officer_pool.idle_officers``.
        allocations: Stored as ``officer_pool.allocated`` (keyed by
            ``ServiceType`` member, not by value). ``None`` allocates one
            officer to every service.
        active_cases_by_service: Maps service value -> ``total_pending``.
            Services not listed get 0. ``None`` gives every service 10
            pending cases.

    Returns:
        A ``SimpleNamespace`` exposing ``queue_snapshots`` (one snapshot per
        ``ServiceType`` member, keyed by service value), ``officer_pool``
        and a set of fixed episode-level fields.
    """
    services = list(ServiceType)

    # Compare against None instead of relying on truthiness: an explicitly
    # passed empty dict must be honoured, not silently replaced by defaults.
    if missing_doc_counts is None:
        missing_doc_counts = {}
    if urgent_counts is None:
        urgent_counts = {}
    if active_cases_by_service is None:
        active_cases_by_service = {svc.value: 10 for svc in services}
    if allocations is None:
        allocations = {svc: 1 for svc in services}

    snapshots = {
        svc.value: SimpleNamespace(
            service_type=svc,
            total_pending=active_cases_by_service.get(svc.value, 0),
            avg_waiting_days=3.0,
            urgent_pending=urgent_counts.get(svc.value, 2),
            blocked_missing_docs=missing_doc_counts.get(svc.value, 0),
            escalated_cases=0,
            public_stage_counts={},
        )
        for svc in services
    }
    return SimpleNamespace(
        queue_snapshots=snapshots,
        escalation_budget_remaining=escalation_budget,
        officer_pool=SimpleNamespace(
            # Exposed as a zero-argument callable, not a plain attribute.
            total_officers=lambda: 10,
            allocated=allocations,
            idle_officers=reserve_officers,
        ),
        day=5, max_days=30, total_backlog=50, total_completed=20,
        total_sla_breaches=3, fairness_gap=0.1,
        last_action_valid=True, last_action_message="ok",
    )


@pytest.fixture
def amc():
    """Provide a default-constructed ActionMaskComputer to the requesting test."""
    computer = ActionMaskComputer()
    return computer


def test_advance_time_always_valid(amc):
    """Mask slot 18 is enabled for a default observation in balanced mode."""
    observation = _make_obs()
    mask = amc.compute(observation, "balanced")
    assert mask[18]


def test_escalate_blocked_when_budget_zero(amc):
    """With a zero escalation budget, every escalate_service slot is masked out."""
    observation = _make_obs(escalation_budget=0, urgent_counts={"passport": 5})
    mask = amc.compute(observation, "balanced")
    for slot, (kind, _, _, _) in ACTION_DECODE_TABLE.items():
        if kind != "escalate_service":
            continue
        assert not mask[slot]


def test_missing_docs_blocked_when_no_pending(amc):
    """With no cases blocked on documents, request_missing_documents is masked out."""
    observation = _make_obs(missing_doc_counts={})
    mask = amc.compute(observation, "balanced")
    for slot, (kind, _, _, _) in ACTION_DECODE_TABLE.items():
        if kind != "request_missing_documents":
            continue
        assert not mask[slot]


def test_missing_docs_valid_when_pending(amc):
    """A service with blocked cases keeps its request_missing_documents slot enabled."""
    target_service = list(ServiceType)[0].value
    observation = _make_obs(missing_doc_counts={target_service: 3})
    mask = amc.compute(observation, "balanced")
    for slot, (kind, service, _, _) in ACTION_DECODE_TABLE.items():
        # Only the document-request slots for the targeted service are checked.
        if kind != "request_missing_documents" or service != target_service:
            continue
        assert mask[slot]


def test_reallocate_blocked_when_source_has_no_alloc(amc):
    """When no service has officers allocated, reallocate_officers is masked out."""
    no_officers_anywhere = dict.fromkeys(ServiceType, 0)
    observation = _make_obs(allocations=no_officers_anywhere)
    mask = amc.compute(observation, "balanced")
    for slot, (kind, _, _, _) in ACTION_DECODE_TABLE.items():
        if kind != "reallocate_officers":
            continue
        assert not mask[slot]


def test_assign_capacity_blocked_when_no_reserve(amc):
    """With zero idle officers, every assign_capacity slot is masked out."""
    observation = _make_obs(reserve_officers=0)
    mask = amc.compute(observation, "balanced")
    for slot, (kind, _, _, _) in ACTION_DECODE_TABLE.items():
        if kind != "assign_capacity":
            continue
        assert not mask[slot]


def test_reallocate_blocked_when_only_one_active_service(amc):
    """If only one service has pending cases, reallocate_officers is masked out."""
    sole_active = list(ServiceType)[0].value
    # Ten pending cases for the first service, none for every other one.
    pending_by_service = {
        svc.value: (10 if svc.value == sole_active else 0)
        for svc in ServiceType
    }
    observation = _make_obs(active_cases_by_service=pending_by_service)
    mask = amc.compute(observation, "balanced")
    for slot, (kind, _, _, _) in ACTION_DECODE_TABLE.items():
        if kind != "reallocate_officers":
            continue
        assert not mask[slot]


def test_redundant_priority_mode_blocked(amc):
    """Slot 0 is masked out when the current priority mode is "urgent_first".

    NOTE(review): slot 0 is presumably the action that switches to
    "urgent_first" -- confirm against ACTION_DECODE_TABLE.
    """
    observation = _make_obs()
    mask = amc.compute(observation, current_priority_mode="urgent_first")
    assert not mask[0]


def test_mask_length(amc):
    """The computed mask has exactly N_ACTIONS entries."""
    observation = _make_obs()
    mask = amc.compute(observation, "balanced")
    assert len(mask) == N_ACTIONS


def test_at_least_one_valid_action(amc):
    """A default observation leaves at least one action enabled."""
    observation = _make_obs()
    mask = amc.compute(observation, "balanced")
    assert mask.any()


def test_only_advance_time_when_backlog_zero(amc):
    """With no pending cases and zero total backlog, slot 18 is the only enabled action."""
    nothing_pending = dict.fromkeys((svc.value for svc in ServiceType), 0)
    observation = _make_obs(active_cases_by_service=nothing_pending)
    # _make_obs hard-codes total_backlog=50, so it is overridden here.
    observation.total_backlog = 0
    mask = amc.compute(observation, "balanced")
    enabled_count = int(mask.sum())
    assert mask[18]
    assert enabled_count == 1