Spaces:
Running
Running
| """ | |
| Converts ObservationModel (Pydantic) → flat numpy float32 vector. | |
| All downstream RL code depends on OBS_DIM being stable. | |
| Feature layout (total = OBS_DIM = 84): | |
| [0 : 63) — per-service block (7 services × 9 features each) | |
| [63 : 84) — global block (21 scalar features) | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| from typing import List | |
| from app.models import ( | |
| ObservationModel, | |
| ServiceType, | |
| StageType, | |
| PriorityMode, | |
| ActionType, | |
| ) | |
# -- Canonical orderings (must never change across the codebase) --------------
# Positions in these lists determine where each service / stage / mode /
# action type lands in the feature vector and the one-hot encodings.
SERVICES: List[ServiceType] = [
    ServiceType.PASSPORT,
    ServiceType.DRIVING_LICENSE,
    ServiceType.GST_REGISTRATION,
    ServiceType.INCOME_CERTIFICATE,
    ServiceType.CASTE_CERTIFICATE,
    ServiceType.BIRTH_CERTIFICATE,
    ServiceType.LAND_REGISTRATION,
]  # 7 legacy RL services for checkpoint compatibility

# The remaining orderings follow enum definition order.
STAGES: List[StageType] = [*StageType]  # 5 stages
PRIORITY_MODES: List[PriorityMode] = [*PriorityMode]  # 4 modes
ACTION_TYPES: List[ActionType] = [*ActionType]  # 6 types

# Reverse lookups: enum member -> position in its canonical list.
SERVICE_IDX = {member: pos for pos, member in enumerate(SERVICES)}
STAGE_IDX = {member: pos for pos, member in enumerate(STAGES)}
PM_IDX = {member: pos for pos, member in enumerate(PRIORITY_MODES)}
AT_IDX = {member: pos for pos, member in enumerate(ACTION_TYPES)}
# -- Dimension constants -------------------------------------------------------
N_SERVICES = len(SERVICES)  # 7
N_STAGES = len(STAGES)  # 5
N_PRIORITY_MODES = len(PRIORITY_MODES)  # 4
N_ACTION_TYPES = len(ACTION_TYPES)  # 6

# Per service: queue_len, avg_wait, urgent, missing + one fraction per stage = 9
PER_SERVICE_DIM = 4 + N_STAGES

# Width of the global block, listed in the order the features are written.
GLOBAL_DIM = sum(
    (
        1,  # day_ratio
        1,  # total_backlog_normalized
        1,  # total_completed_normalized
        1,  # total_sla_breaches_normalized
        1,  # fairness_gap
        1,  # escalation_budget_ratio
        1,  # last_action_valid
        N_ACTION_TYPES,  # last_action_type one-hot (6)
        N_PRIORITY_MODES,  # current_priority_mode one-hot (4)
        1,  # idle_officer_ratio
        1,  # urgent_backlog_ratio
        1,  # officer_utilization
        1,  # backlog_per_officer
    )
)  # = 21

OBS_DIM = N_SERVICES * PER_SERVICE_DIM + GLOBAL_DIM  # 63 + 21 = 84

# -- Normalisation caps --------------------------------------------------------
# Non-zero divisors used to scale raw counts. FeatureBuilder.build only clips
# the officer-derived ratios to [0, 1]; the other scaled features can exceed
# 1.0 when a raw value is larger than its cap.
_MAX_QUEUE = 200.0
_MAX_WAIT = 30.0
_MAX_URGENT = 50.0
_MAX_MISSING = 50.0
_MAX_BACKLOG = 500.0
_MAX_COMPLETED = 500.0
_MAX_SLA = 100.0
_MAX_ESC_BUDGET = 20.0
_MAX_OFFICERS = 50.0  # used as the divisor for backlog-per-officer
class FeatureBuilder:
    """
    Stateless transformer: ObservationModel -> np.ndarray[float32, OBS_DIM].

    Vector layout produced by build():
      * One block of PER_SERVICE_DIM values per entry in SERVICES (in that
        order): pending / _MAX_QUEUE, average age / _MAX_WAIT,
        urgent / _MAX_URGENT, missing-docs / _MAX_MISSING, then one stage
        fraction per entry in STAGES. Services without a snapshot stay zero.
      * A global block: 7 scalars, the last-action one-hot, the priority-mode
        one-hot, and 4 officer-derived ratios.

    Only the 4 officer-derived ratios are clipped to [0, 1]; the other scaled
    features are plain divisions by their cap.

    Usage:
        fb = FeatureBuilder()
        vec = fb.build(obs, current_priority_mode="urgent_first",
                       last_action_type="advance_time")
    """

    @staticmethod
    def _attr(obj, *names, default=0):
        """Return the first attribute among *names* that *obj* defines.

        The observation/snapshot schema differs between model versions, so
        each feature is looked up under several candidate names. If none of
        the names exists, or the first one that exists holds None, *default*
        is returned.
        """
        for name in names:
            try:
                value = getattr(obj, name)
            except AttributeError:
                continue
            return default if value is None else value
        return default

    @staticmethod
    def _one_hot_index(enum_cls, index_of, raw):
        """Return the one-hot slot for *raw*, or None if it is not a known member."""
        try:
            return index_of[enum_cls(raw)]
        except (ValueError, KeyError):
            # Unknown value: the caller leaves the whole one-hot span at zero.
            return None

    def build(
        self,
        obs: ObservationModel,
        current_priority_mode: str = "balanced",
        last_action_type: str = "advance_time",
    ) -> np.ndarray:
        """Encode *obs* as a flat float32 vector of length OBS_DIM.

        Args:
            obs: Observation to encode. Reads queue_snapshots (dict or
                iterable of snapshots), day, max_days, total_backlog,
                total_completed, total_sla_breaches, fairness_gap (or
                fairness_index), escalation_budget_remaining,
                last_action_valid and officer_pool.
            current_priority_mode: PriorityMode value to one-hot encode; an
                unrecognised value leaves the one-hot span all zeros.
            last_action_type: ActionType value to one-hot encode; an
                unrecognised value leaves the one-hot span all zeros.

        Returns:
            np.ndarray of shape (OBS_DIM,) and dtype float32.

        Raises:
            AssertionError: if the number of written slots disagrees with
                OBS_DIM (layout constants out of sync).
        """
        read = FeatureBuilder._attr
        features = np.zeros(OBS_DIM, dtype=np.float32)

        snapshots = obs.queue_snapshots
        if isinstance(snapshots, dict):
            snapshots = snapshots.values()
        snap_dict = {snap.service_type: snap for snap in snapshots}

        # -- Per-service block -------------------------------------------------
        offset = 0
        urgent_total = 0
        for svc in SERVICES:
            snap = snap_dict.get(svc)
            if snap is not None:
                pending = read(snap, "total_pending", "active_cases")
                urgent = read(snap, "urgent_cases", "urgent_pending")
                urgent_total += urgent

                features[offset + 0] = pending / _MAX_QUEUE
                features[offset + 1] = (
                    read(snap, "avg_age_days", "avg_waiting_days", "oldest_case_age_days")
                    / _MAX_WAIT
                )
                features[offset + 2] = urgent / _MAX_URGENT
                features[offset + 3] = (
                    read(snap, "blocked_missing_docs", "missing_docs_cases") / _MAX_MISSING
                )

                # Stage distribution as fractions of the service's pending
                # cases; the denominator is floored at 1 to avoid div-by-zero.
                # NOTE(review): the lookup key is the StageType member itself;
                # assumes stage_counts is keyed compatibly — confirm.
                stage_counts = (
                    read(snap, "stage_counts", "public_stage_counts", default=None) or {}
                )
                denom = max(pending, 1)
                for stg in STAGES:
                    features[offset + 4 + STAGE_IDX[stg]] = stage_counts.get(stg, 0) / denom
            # A missing service keeps its zero-filled block but still
            # consumes its slots so later offsets stay aligned.
            offset += PER_SERVICE_DIM

        # -- Global block ------------------------------------------------------
        features[offset + 0] = obs.day / max(obs.max_days, 1)
        features[offset + 1] = obs.total_backlog / _MAX_BACKLOG
        features[offset + 2] = obs.total_completed / _MAX_COMPLETED
        features[offset + 3] = obs.total_sla_breaches / _MAX_SLA
        features[offset + 4] = float(
            read(obs, "fairness_gap", "fairness_index", default=0.0) or 0.0
        )
        features[offset + 5] = obs.escalation_budget_remaining / _MAX_ESC_BUDGET
        features[offset + 6] = float(obs.last_action_valid)
        offset += 7

        # Last action type one-hot (features is zero-initialised, so only the
        # hot slot needs writing).
        hot = FeatureBuilder._one_hot_index(ActionType, AT_IDX, last_action_type)
        if hot is not None:
            features[offset + hot] = 1.0
        offset += N_ACTION_TYPES

        # Current priority mode one-hot
        hot = FeatureBuilder._one_hot_index(PriorityMode, PM_IDX, current_priority_mode)
        if hot is not None:
            features[offset + hot] = 1.0
        offset += N_PRIORITY_MODES

        # Officer-derived scalars
        pool = obs.officer_pool
        # total_officers may be a plain attribute or a method depending on
        # the pool model; a missing or None value counts as 1.
        raw_total = getattr(pool, "total_officers", 1)
        if callable(raw_total):
            raw_total = raw_total()
        if raw_total is None:
            raw_total = 1
        total_officers = max(raw_total, 1)
        idle_officers = read(pool, "idle_officers", "reserve_officers")

        idle_ratio = idle_officers / total_officers
        urgent_ratio = urgent_total / max(obs.total_backlog, 1)
        utilization = (total_officers - idle_officers) / total_officers
        backlog_per_off = obs.total_backlog / total_officers

        features[offset + 0] = float(np.clip(idle_ratio, 0.0, 1.0))
        features[offset + 1] = float(np.clip(urgent_ratio, 0.0, 1.0))
        features[offset + 2] = float(np.clip(utilization, 0.0, 1.0))
        features[offset + 3] = float(np.clip(backlog_per_off / _MAX_OFFICERS, 0.0, 1.0))

        # Layout sanity check: the four officer scalars must end exactly at OBS_DIM.
        assert offset + 4 == OBS_DIM, f"OBS_DIM mismatch: {offset + 4} != {OBS_DIM}"
        return features
# -- Action space layout (N_ACTIONS = 28) -------------------------------------
#
#   0 -  3 : set_priority_mode (4 modes in PRIORITY_MODES order)
#   4 - 10 : request_missing_documents per service (7)
#  11 - 17 : escalate_service per service (7)
#  18      : advance_time
#  19 - 25 : reallocate_officers from source service -> most loaded other service
#  26      : assign_capacity +1 to most-loaded service
#  27      : assign_capacity +1 to most-urgent service
#
# The priority-mode span is sized from N_PRIORITY_MODES (not a literal 4) so
# that N_ACTIONS always agrees with the loop over PRIORITY_MODES below.
N_ACTIONS = N_PRIORITY_MODES + N_SERVICES + N_SERVICES + 1 + N_SERVICES + 2  # = 28

# Maps a discrete action index to a 4-tuple:
#   [0] action type name
#   [1] service value, a "__most_loaded__" / "__most_urgent__" sentinel, or None
#   [2] priority mode value, the "most_loaded_other" target marker, or None
#   [3] 1 for reallocate/assign actions, otherwise None
#       (presumably the officer/capacity delta — confirm against the decoder)
ACTION_DECODE_TABLE = {}
idx = 0
for m in PRIORITY_MODES:
    ACTION_DECODE_TABLE[idx] = ("set_priority_mode", None, m.value, None)
    idx += 1
for s in SERVICES:
    ACTION_DECODE_TABLE[idx] = ("request_missing_documents", s.value, None, None)
    idx += 1
for s in SERVICES:
    ACTION_DECODE_TABLE[idx] = ("escalate_service", s.value, None, None)
    idx += 1
ACTION_DECODE_TABLE[idx] = ("advance_time", None, None, None)
idx += 1
for s in SERVICES:
    ACTION_DECODE_TABLE[idx] = ("reallocate_officers", s.value, "most_loaded_other", 1)
    idx += 1
ACTION_DECODE_TABLE[idx] = ("assign_capacity", "__most_loaded__", None, 1)
idx += 1
ACTION_DECODE_TABLE[idx] = ("assign_capacity", "__most_urgent__", None, 1)
idx += 1

# Guard against the layout formula and the table drifting apart.
assert len(ACTION_DECODE_TABLE) == N_ACTIONS, (
    f"ACTION_DECODE_TABLE has {len(ACTION_DECODE_TABLE)} entries, expected {N_ACTIONS}"
)