Spaces:
Running
Running
"""
Gymnasium adapter for GovWorkflowEnv.

Wraps the core ``GovWorkflowEnv`` in a ``gymnasium.Env`` so it can be driven
by Gymnasium-compatible RL code, and exposes per-action validity masks.

Key contract:
    observation_space : Box(OBS_DIM,) float32, bounds [0, 1]
    action_space      : Discrete(N_ACTIONS)
    action_masks()    : np.ndarray[bool, N_ACTIONS]
"""
| from __future__ import annotations | |
| from typing import Any, Callable, Optional | |
| import gymnasium as gym | |
| import numpy as np | |
| from gymnasium import spaces | |
| from app.env import GovWorkflowEnv | |
| from app.models import ActionModel, ActionType, ObservationModel, PriorityMode, ServiceType | |
| from rl.action_mask import ActionMaskComputer | |
| from rl.feature_builder import ACTION_DECODE_TABLE, N_ACTIONS, OBS_DIM, FeatureBuilder | |
class GovWorkflowGymEnv(gym.Env):
    """Gymnasium wrapper around :class:`GovWorkflowEnv`.

    Translates between the core environment's observation/action models and
    the flat spaces expected by RL libraries:

    * observations are encoded by :class:`FeatureBuilder` into a float32
      vector of length ``OBS_DIM`` (declared bounds ``[0, 1]``);
    * actions are integers in ``[0, N_ACTIONS)`` decoded through
      ``ACTION_DECODE_TABLE``;
    * :meth:`action_masks` returns a boolean validity flag per action.

    Two optional safeguards:

    * ``hard_action_mask`` -- when enabled, :meth:`step` replaces a masked
      action with a valid one before executing it;
    * ``max_non_advance_streak`` -- after this many consecutive actions that
      are not ``ADVANCE_TIME``, :meth:`action_masks` allows only the
      advance-time action (``0`` disables the limit).
    """

    metadata = {"render_modes": []}

    # Flat action index used as the forced / fallback "advance time" action by
    # the streak limit and by hard-mask sanitisation.
    # NOTE(review): assumes ACTION_DECODE_TABLE maps index 18 to
    # "advance_time" -- confirm against rl.feature_builder.
    _ADVANCE_TIME_IDX: int = 18

    def __init__(
        self,
        task_id: str = "district_backlog_easy",
        seed: int = 42,
        hard_action_mask: bool = False,
        max_non_advance_streak: int = 3,
    ):
        """Create the wrapper and its underlying core environment.

        Args:
            task_id: Task identifier forwarded to the core env on reset.
                Overwritten on every reset when a task sampler is installed.
            seed: Seed passed to the core env when ``reset`` gets no seed.
            hard_action_mask: If True, :meth:`step` remaps masked actions to a
                valid one before execution.
            max_non_advance_streak: Consecutive non-advance actions allowed
                before only advance-time is unmasked; values <= 0 disable it.
        """
        super().__init__()
        self.task_id = task_id
        self._seed = seed
        # Optional hook called on every reset to choose the next task id.
        self._task_sampler: Optional[Callable[[], str]] = None
        # Optional caller-owned list; element [0] is incremented once per step
        # (presumably shared across envs by the caller -- verify).
        self._global_step_counter: Optional[list[int]] = None
        self._hard_action_mask: bool = bool(hard_action_mask)
        self._max_non_advance_streak = max(0, int(max_non_advance_streak))
        self._non_advance_streak = 0

        self._core_env = GovWorkflowEnv()
        self._fb = FeatureBuilder()
        self._amc = ActionMaskComputer()

        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(OBS_DIM,),
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(N_ACTIONS)

        # Latest core observation; None until the first reset().
        self._current_obs: Optional[ObservationModel] = None
        # Priority mode and last executed action type, both fed to the
        # feature builder and (the former) to the mask computer.
        self._current_pm: str = "balanced"
        self._last_at: str = "advance_time"

    def core_env(self) -> GovWorkflowEnv:
        """Return the wrapped :class:`GovWorkflowEnv` instance."""
        return self._core_env

    def set_hard_action_mask(self, enabled: bool) -> None:
        """Enable or disable remapping of masked actions inside :meth:`step`."""
        self._hard_action_mask = bool(enabled)

    def set_task_sampler(
        self,
        task_sampler: Optional[Callable[[], str]],
        global_step_counter: Optional[list[int]] = None,
    ) -> None:
        """Install (or clear with None) a per-reset task sampler.

        Args:
            task_sampler: Zero-argument callable returning the task id to use
                on the next reset, or None to keep ``self.task_id`` fixed.
            global_step_counter: Optional list whose element [0] is
                incremented on every :meth:`step`.
        """
        self._task_sampler = task_sampler
        self._global_step_counter = global_step_counter

    @staticmethod
    def _info_to_dict(info: Any) -> dict:
        """Coerce the core env's info payload into a fresh plain dict.

        Objects with ``model_dump()`` are dumped first; the result is then
        shallow-copied with ``dict()``, so keys added by this wrapper never
        leak back into a dict the core env may still reference. Anything that
        cannot be converted yields an empty dict.
        """
        raw = info.model_dump() if hasattr(info, "model_dump") else info
        try:
            return dict(raw)
        except (TypeError, ValueError):
            return {}

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Start a new episode on the core environment.

        Args:
            seed: Seed for this episode; falls back to the constructor seed.
            options: Extra reset options merged over ``{"task_id": ...}``
                (so an explicit ``"task_id"`` here overrides the current one).

        Returns:
            ``(observation_vector, info)``; ``info`` carries an added
            ``"fairness_gap"`` entry.
        """
        super().reset(seed=seed)
        if self._task_sampler is not None:
            self.task_id = self._task_sampler()

        # NOTE(review): a seedless reset always reuses ``self._seed``, so
        # repeated reset() calls on the same task re-seed the core env
        # identically. Gymnasium's convention is to seed only once; confirm
        # this repetition is intended.
        use_seed = seed if seed is not None else self._seed
        task_opts = {"task_id": self.task_id}
        if options:
            task_opts.update(options)

        obs_model, info = self._core_env.reset(seed=use_seed, options=task_opts)
        self._current_obs = obs_model
        self._current_pm = "balanced"
        self._last_at = "advance_time"
        self._non_advance_streak = 0

        info_dict = self._info_to_dict(info)
        info_dict["fairness_gap"] = self._obs_fairness_gap(obs_model)
        return self._to_array(obs_model), info_dict

    def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Decode ``action``, advance the core env, and encode the result.

        With the hard mask enabled, a masked action is replaced before
        decoding (see :meth:`_sanitize_action_idx`).

        Returns:
            ``(observation_vector, reward, terminated, truncated, info)``.
            ``info`` gains ``"fairness_gap"``, ``"requested_action_idx"``,
            ``"executed_action_idx"`` and ``"action_mask_applied"``.
        """
        requested_action_idx = int(action)
        action_idx = requested_action_idx
        if self._hard_action_mask and self._current_obs is not None:
            action_idx = self._sanitize_action_idx(requested_action_idx, self.action_masks())

        action_model = self._decode_action(action_idx)
        obs_model, reward, terminated, truncated, info = self._core_env.step(action_model)
        if self._global_step_counter is not None:
            self._global_step_counter[0] += 1

        self._current_obs = obs_model
        self._last_at = action_model.action_type.value
        if getattr(action_model, "priority_mode", None) is not None:
            self._current_pm = action_model.priority_mode.value

        # Any executed action other than ADVANCE_TIME (including decode
        # fallbacks, which become ADVANCE_TIME) drives the streak limit.
        if action_model.action_type == ActionType.ADVANCE_TIME:
            self._non_advance_streak = 0
        else:
            self._non_advance_streak += 1

        info_dict = self._info_to_dict(info)
        info_dict["fairness_gap"] = self._obs_fairness_gap(obs_model)
        info_dict["requested_action_idx"] = requested_action_idx
        info_dict["executed_action_idx"] = action_idx
        info_dict["action_mask_applied"] = bool(action_idx != requested_action_idx)
        # Gymnasium expects plain bools for the termination flags.
        return self._to_array(obs_model), float(reward), bool(terminated), bool(truncated), info_dict

    def action_masks(self) -> np.ndarray:
        """Return a boolean array of length ``N_ACTIONS`` marking valid actions.

        Before the first reset every action is allowed. Once the
        non-advance streak limit is reached, only the advance-time action is
        allowed; otherwise the mask comes from :class:`ActionMaskComputer`.
        """
        if self._current_obs is None:
            return np.ones(N_ACTIONS, dtype=bool)
        # Check the cheap streak limit first so the full mask is not computed
        # only to be thrown away.
        if 0 < self._max_non_advance_streak <= self._non_advance_streak:
            forced = np.zeros(N_ACTIONS, dtype=bool)
            forced[self._ADVANCE_TIME_IDX] = True
            return forced
        return self._amc.compute(self._current_obs, self._current_pm)

    def render(self) -> None:
        """Rendering is not supported (``metadata["render_modes"]`` is empty)."""
        return None

    def _to_array(self, obs: ObservationModel) -> np.ndarray:
        """Encode ``obs`` into the flat vector described by ``observation_space``."""
        # Cast defensively so the output always matches the declared Box dtype;
        # np.asarray does not copy when the builder already returns float32.
        return np.asarray(self._fb.build(obs, self._current_pm, self._last_at), dtype=np.float32)

    def _queue_snapshot_iter(self) -> list[Any]:
        """Return the current observation's queue snapshots as a list.

        ``queue_snapshots`` may be a dict (its values are used) or any other
        iterable; a missing, None, or non-iterable value yields ``[]``.
        """
        if self._current_obs is None:
            return []
        raw = getattr(self._current_obs, "queue_snapshots", [])
        if raw is None:
            return []
        if isinstance(raw, dict):
            return list(raw.values())
        try:
            return list(raw)
        except Exception:
            # Best-effort: an unexpected snapshot container disables the
            # queue-based heuristics instead of crashing the episode.
            return []

    def _queue_service(self, snap: Any) -> Optional[ServiceType]:
        """Resolve a snapshot's service (``service_type`` or ``service``), or None."""
        value = getattr(snap, "service_type", None) or getattr(snap, "service", None)
        if value is None:
            return None
        if isinstance(value, ServiceType):
            return value
        try:
            return ServiceType(str(value))
        except Exception:
            return None

    def _queue_active_cases(self, snap: Any) -> int:
        """Pending-case count of a snapshot (``total_pending`` or ``active_cases``)."""
        return int(getattr(snap, "total_pending", getattr(snap, "active_cases", 0)) or 0)

    def _queue_urgent_cases(self, snap: Any) -> int:
        """Urgent-case count of a snapshot (``urgent_pending`` or ``urgent_cases``)."""
        return int(getattr(snap, "urgent_pending", getattr(snap, "urgent_cases", 0)) or 0)

    def _obs_fairness_gap(self, obs: ObservationModel) -> float:
        """
        Canonical fairness signal for RL info payload.

        Current ObservationModel exposes fairness as `fairness_index`, while
        episode-level grading uses `fairness_gap` from EpisodeStateModel.
        Keep backward-compatible fallback to avoid runtime breaks.
        """
        # NOTE(review): when only `fairness_index` exists its value is reported
        # under the "fairness_gap" key unchanged; confirm both share the same
        # orientation (higher = worse vs. higher = better).
        return float(getattr(obs, "fairness_gap", getattr(obs, "fairness_index", 0.0)) or 0.0)

    def _build_action_model(self, action_type: ActionType, **kwargs: Any) -> ActionModel:
        """Build an :class:`ActionModel`, probing several field-name layouts.

        Candidate keyword layouts are tried in order and the first one that
        ``ActionModel`` accepts wins; if none is accepted (or the required
        services are missing) an ``ADVANCE_TIME`` action is returned.

        Keyword Args:
            service: Primary :class:`ServiceType` for service-scoped actions.
            target_service: Destination service for reallocation.
            officer_delta: Officer count (falsy values become 1).
            priority_mode: :class:`PriorityMode` for ``SET_PRIORITY_MODE``.
        """
        service = kwargs.get("service")
        target_service = kwargs.get("target_service")
        officer_delta = int(kwargs.get("officer_delta", 1) or 1)
        priority_mode = kwargs.get("priority_mode")

        candidates: list[dict[str, Any]] = []
        if action_type == ActionType.ADVANCE_TIME:
            candidates.append({"action_type": action_type})
        elif action_type == ActionType.SET_PRIORITY_MODE:
            candidates.append({"action_type": action_type, "priority_mode": priority_mode})
        elif action_type == ActionType.ASSIGN_CAPACITY and service is not None:
            candidates.extend(
                [
                    {"action_type": action_type, "service": service, "officer_delta": officer_delta},
                    {"action_type": action_type, "service_target": service, "officer_delta": officer_delta},
                    {"action_type": action_type, "capacity_assignment": {service.value: officer_delta}},
                ]
            )
        elif action_type == ActionType.REQUEST_MISSING_DOCUMENTS and service is not None:
            candidates.extend(
                [
                    {"action_type": action_type, "service": service},
                    {"action_type": action_type, "service_target": service},
                ]
            )
        elif action_type == ActionType.ESCALATE_SERVICE and service is not None:
            candidates.extend(
                [
                    {"action_type": action_type, "service": service},
                    {"action_type": action_type, "service_target": service},
                    {"action_type": action_type, "escalation_target": service},
                ]
            )
        elif action_type == ActionType.REALLOCATE_OFFICERS and service is not None and target_service is not None:
            candidates.extend(
                [
                    {
                        "action_type": action_type,
                        "service": service,
                        "target_service": target_service,
                        "officer_delta": officer_delta,
                    },
                    {
                        "action_type": action_type,
                        "reallocation_delta": {
                            service.value: -officer_delta,
                            target_service.value: officer_delta,
                        },
                    },
                ]
            )

        for candidate in candidates:
            try:
                return ActionModel(**candidate)
            except Exception:
                # Deliberate schema probing: a rejected layout just means the
                # next candidate should be tried.
                continue
        return ActionModel(action_type=ActionType.ADVANCE_TIME)

    @staticmethod
    def _concrete_service(service_str: Optional[str]) -> Optional[ServiceType]:
        """Return ``ServiceType(service_str)`` for a concrete name, else None.

        Empty values and ``"__"``-prefixed sentinels (such as
        ``"__most_loaded__"``) are not concrete services and map to None.
        """
        if not service_str or service_str.startswith("__"):
            return None
        return ServiceType(service_str)

    def _decode_action(self, action_idx: int) -> ActionModel:
        """Translate a flat action index into an :class:`ActionModel`.

        Unknown indices, and actions whose target service cannot be resolved
        from the current observation, fall back to ``ADVANCE_TIME``.
        """
        if action_idx not in ACTION_DECODE_TABLE:
            return ActionModel(action_type=ActionType.ADVANCE_TIME)
        action_type_str, service_str, priority_mode_str, delta = ACTION_DECODE_TABLE[action_idx]
        action_type = ActionType(action_type_str)

        if action_type == ActionType.SET_PRIORITY_MODE and priority_mode_str is not None:
            return self._build_action_model(
                action_type,
                priority_mode=PriorityMode(priority_mode_str),
            )

        if action_type == ActionType.ASSIGN_CAPACITY:
            if service_str == "__most_loaded__":
                target = self._find_most_loaded_service()
            elif service_str == "__most_urgent__":
                target = self._find_most_urgent_service()
            else:
                target = self._concrete_service(service_str)
            if target is None:
                return ActionModel(action_type=ActionType.ADVANCE_TIME)
            return self._build_action_model(
                action_type,
                service=target,
                officer_delta=max(int(delta or 1), 1),
            )

        if action_type == ActionType.REQUEST_MISSING_DOCUMENTS:
            # Any non-concrete entry targets the most loaded queue.
            target = self._concrete_service(service_str)
            if target is None:
                target = self._find_most_loaded_service()
            if target is None:
                return ActionModel(action_type=ActionType.ADVANCE_TIME)
            return self._build_action_model(action_type, service=target)

        if action_type == ActionType.ESCALATE_SERVICE:
            # Any non-concrete entry targets the most urgent queue.
            target = self._concrete_service(service_str)
            if target is None:
                target = self._find_most_urgent_service()
            if target is None:
                return ActionModel(action_type=ActionType.ADVANCE_TIME)
            return self._build_action_model(action_type, service=target)

        if action_type == ActionType.REALLOCATE_OFFICERS:
            # Reallocation needs a concrete source; ServiceType() would raise
            # ValueError on an empty value or a "__" sentinel.
            source = self._concrete_service(service_str)
            if source is None:
                return ActionModel(action_type=ActionType.ADVANCE_TIME)
            target = self._find_reallocation_target(source)
            if target is None:
                return ActionModel(action_type=ActionType.ADVANCE_TIME)
            return self._build_action_model(
                action_type,
                service=source,
                target_service=target,
                officer_delta=1,
            )

        return ActionModel(action_type=ActionType.ADVANCE_TIME)

    def _find_most_loaded_service(self) -> Optional[ServiceType]:
        """Service of the snapshot with the most pending cases, or None if no snapshots."""
        snaps = self._queue_snapshot_iter()
        if not snaps:
            return None
        best = max(snaps, key=self._queue_active_cases)
        return self._queue_service(best)

    def _find_most_urgent_service(self) -> Optional[ServiceType]:
        """Service with the most urgent cases (ties broken by pending count).

        Returns None when no snapshot has any urgent case.
        """
        snaps = [snap for snap in self._queue_snapshot_iter() if self._queue_urgent_cases(snap) > 0]
        if not snaps:
            return None
        best = max(snaps, key=lambda snap: (self._queue_urgent_cases(snap), self._queue_active_cases(snap)))
        return self._queue_service(best)

    def _find_reallocation_target(self, source: ServiceType) -> Optional[ServiceType]:
        """Most loaded service other than ``source``; None if none has pending cases."""
        snaps = [snap for snap in self._queue_snapshot_iter() if self._queue_service(snap) != source]
        if not snaps:
            return None
        best = max(snaps, key=self._queue_active_cases)
        if self._queue_active_cases(best) <= 0:
            return None
        return self._queue_service(best)

    def _sanitize_action_idx(self, action_idx: int, masks: np.ndarray) -> int:
        """Return ``action_idx`` if allowed by ``masks``, else a valid substitute.

        Preference order: the requested index, the advance-time index, the
        lowest valid index. With an all-False mask the advance-time index is
        returned anyway.
        """
        fallback = self._ADVANCE_TIME_IDX
        if 0 <= action_idx < N_ACTIONS and bool(masks[action_idx]):
            return action_idx
        if 0 <= fallback < N_ACTIONS and bool(masks[fallback]):
            return fallback
        valid = np.flatnonzero(masks)
        if valid.size == 0:
            return fallback
        return int(valid[0])