File size: 4,731 Bytes
df97e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Computes a boolean action mask (length N_ACTIONS) from the current observation.
True  = action is structurally valid right now.
False = action is impossible/wasteful; MaskablePPO will zero its logit.
"""

from __future__ import annotations

import numpy as np
from app.models import ObservationModel
from rl.feature_builder import ACTION_DECODE_TABLE, N_ACTIONS


class ActionMaskComputer:
    """
    Usage:
        amc = ActionMaskComputer()
        mask = amc.compute(obs, current_priority_mode)
    """

    def compute(
        self,
        obs: ObservationModel,
        current_priority_mode: str = "balanced",
    ) -> np.ndarray:
        mask = np.ones(N_ACTIONS, dtype=bool)
        total_backlog = int(getattr(obs, "total_backlog", 0) or 0)

        # Prevent reward farming with no-op control actions when nothing is queued.
        # In this state, time must advance to generate arrivals and meaningful decisions.
        if total_backlog <= 0:
            mask[:] = False
            for action_idx, (action_type, _service, _pm, _delta) in ACTION_DECODE_TABLE.items():
                if action_type == "advance_time":
                    mask[action_idx] = True
                    break
            return mask

        queue_snaps = obs.queue_snapshots.values() if isinstance(obs.queue_snapshots, dict) else obs.queue_snapshots
        queue_snaps = list(queue_snaps)
        snapshots = {
            (snap.service_type.value if hasattr(snap.service_type, "value") else snap.service_type): snap 
            for snap in queue_snaps
        }
        active_services = {
            service for service, snap in snapshots.items()
            if getattr(snap, "total_pending", getattr(snap, "active_cases", 0)) > 0
        }
        escalation_budget = obs.escalation_budget_remaining

        services_with_missing_docs = {
            (snap.service_type.value if hasattr(snap.service_type, "value") else snap.service_type) 
            for snap in queue_snaps
            if getattr(snap, "blocked_missing_docs", getattr(snap, "missing_docs_cases", 0)) > 0
        }
        services_with_escalatable = {
            (snap.service_type.value if hasattr(snap.service_type, "value") else snap.service_type)
            for snap in queue_snaps
            if (getattr(snap, "total_pending", getattr(snap, "active_cases", 0)) - getattr(snap, "urgent_pending", getattr(snap, "escalated_cases", 0))) > 0
        }

        allocations = {}
        for service_key, value in (getattr(obs.officer_pool, "allocated", getattr(obs.officer_pool, "allocations", {})) or {}).items():
            name = service_key.value if hasattr(service_key, "value") else str(service_key)
            allocations[name] = int(value)

        idle_officers = getattr(obs.officer_pool, "idle_officers", getattr(obs.officer_pool, "reserve_officers", 0))

        for action_idx, (action_type, service, priority_mode, delta) in ACTION_DECODE_TABLE.items():

            if action_type == "set_priority_mode":
                if priority_mode == current_priority_mode:
                    mask[action_idx] = False

            elif action_type == "request_missing_documents":
                mask[action_idx] = service in services_with_missing_docs

            elif action_type == "escalate_service":
                mask[action_idx] = (
                    escalation_budget > 0
                    and service in services_with_escalatable
                )

            elif action_type == "advance_time":
                mask[action_idx] = True

            elif action_type == "reallocate_officers":
                has_source = (allocations.get(service, 0) > 0) and (service in active_services)
                has_target = any(svc != service for svc in active_services)
                mask[action_idx] = has_source and has_target

            elif action_type == "assign_capacity":
                if idle_officers <= 0:
                    mask[action_idx] = False
                elif service == "__most_loaded__":
                    mask[action_idx] = len(active_services) > 0
                elif service == "__most_urgent__":
                    mask[action_idx] = any(
                        getattr(snap, "urgent_cases", getattr(snap, "urgent_pending", 0)) > 0 for snap in queue_snaps
                    )
                else:
                    mask[action_idx] = False

        # Guarantee at least one safe action for MaskablePPO.
        if not mask.any():
            mask[18] = True

        return mask


def compute_mask(obs: ObservationModel, current_priority_mode: str = "balanced") -> np.ndarray:
    """Module-level convenience function."""
    return ActionMaskComputer().compute(obs, current_priority_mode)