Spaces:
Running
Running
| from __future__ import annotations | |
| from collections.abc import Callable | |
| from types import SimpleNamespace | |
| from app.env import GovWorkflowEnv | |
| from app.graders import grade_episode | |
| from app.models import ActionModel, ActionType, ObservationModel, PriorityMode, ServiceType | |
# Call signature of a policy: takes one observation and returns one action.
# The POLICIES registry in this module is typed with this alias.
PolicyFn = Callable[[ObservationModel], ActionModel]
def _snapshots(obs: ObservationModel):
    """Normalise ``obs.queue_snapshots`` into a new list.

    Phase 2 observations store the snapshots in a dict, in which case only
    the dict values are kept; any other iterable (the Phase 1 list form) is
    copied element by element.
    """
    raw = obs.queue_snapshots
    entries = raw.values() if isinstance(raw, dict) else raw
    return list(entries)
def _service_attr(q, *attrs):
    """Return the first of ``attrs`` that is present on ``q`` and not ``None``.

    Attribute names are tried in the order given, which lets callers pass
    both the Phase 2 and the Phase 1 spelling of a field. A value of ``0``
    counts as present and stops the search. When every name is missing or
    ``None`` (or no names are given) the result is ``0``.
    """
    candidates = (getattr(q, name, None) for name in attrs)
    return next((value for value in candidates if value is not None), 0)
def _service_name(q) -> ServiceType:
    """Read the service identifier from a queue snapshot.

    ``service_type`` (the Phase 2 name) wins when it is set to a truthy
    value. Otherwise the Phase 1 attribute ``service`` is returned, which
    is ``None`` when that attribute is missing as well.
    """
    phase2_name = getattr(q, "service_type", None)
    if phase2_name:
        return phase2_name
    return getattr(q, "service", None)
def _service_with_max(obs: ObservationModel, *attrs) -> ServiceType | None:
    """Find the service whose queue snapshot has the largest ``attrs`` value.

    ``attrs`` are alternative attribute names for the same field and are
    resolved per snapshot by ``_service_attr``. When several snapshots tie
    for the maximum, the one that comes first in the snapshot list wins.

    Returns ``None`` when there are no snapshots or when the largest value
    is not greater than zero.
    """
    queues = _snapshots(obs)
    if not queues:
        return None

    def field_value(snapshot):
        return _service_attr(snapshot, *attrs)

    # max() keeps the first maximal element, the same one a stable
    # descending sort would place at index 0.
    leader = max(queues, key=field_value)
    if field_value(leader) > 0:
        return _service_name(leader)
    return None
def _reserve_officers(obs: ObservationModel) -> int:
    """Return the officer pool's idle/reserve count as an ``int``.

    The Phase 2 attribute ``idle_officers`` is used whenever the pool has
    it. Otherwise the Phase 1 attribute ``reserve_officers`` is read, with
    ``0`` as the default when neither attribute exists.
    """
    officer_pool = obs.officer_pool
    # A sentinel separates "attribute absent" from any real value, 0 included.
    absent = object()
    count = getattr(officer_pool, "idle_officers", absent)
    if count is absent:
        count = getattr(officer_pool, "reserve_officers", 0)
    return int(count)
def _alloc_for(obs: ObservationModel, service: ServiceType) -> int:
    """Return how many officers the pool has allocated to ``service``.

    The allocation mapping is read from ``allocated`` (Phase 2); when that
    is missing or empty, ``allocations`` (Phase 1) is used, defaulting to
    an empty dict. The mapping is probed with ``service`` itself first and
    then with ``service.value`` (or ``str(service)`` when there is no
    ``value`` attribute). Missing or falsy entries yield ``0``.
    """
    officer_pool = obs.officer_pool
    table = getattr(officer_pool, "allocated", None)
    if not table:
        table = getattr(officer_pool, "allocations", {})

    count = table.get(service)
    if count is None:
        # The mapping may be keyed by the plain value rather than the member.
        fallback_key = service.value if hasattr(service, "value") else str(service)
        count = table.get(fallback_key, 0)
    return int(count or 0)
def urgent_first_policy(obs: ObservationModel) -> ActionModel:
    """Request missing documents for the service with the most urgent cases.

    The urgent count is read from ``urgent_pending`` or ``urgent_cases``.
    When no service reports a positive urgent count, the action is
    ADVANCE_TIME.

    NOTE(review): this name is rebound to ``greedy_sla_policy`` further down
    the module, so this definition appears to be shadowed once the module
    has finished importing.
    """
    urgent_service = _service_with_max(obs, "urgent_pending", "urgent_cases")
    if not urgent_service:
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    return ActionModel(
        action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
        service_target=urgent_service,
    )
def oldest_first_policy(obs: ObservationModel) -> ActionModel:
    """Return an ADVANCE_TIME action; the observation is not consulted."""
    advance = ActionModel(action_type=ActionType.ADVANCE_TIME)
    return advance
def backlog_clearance_policy(obs: ObservationModel) -> ActionModel:
    """Choose one action aimed at the most backlogged service.

    Rules are tried in order and the first one that applies wins:

    1. If the pool has idle officers, assign one to the service with the
       highest pending load (``total_pending`` / ``active_cases``).
    2. Request missing documents for the service with the most cases blocked
       on documents (``blocked_missing_docs`` / ``missing_docs_cases``).
    3. With at least two queues, move one officer from the least-loaded to
       the most-loaded service, provided the services differ, the load gap
       is at least 3, and the donor currently has more than one officer.
    4. Otherwise advance time.
    """
    load_fields = ("total_pending", "active_cases")

    def load_of(snapshot):
        return _service_attr(snapshot, *load_fields)

    queues = _snapshots(obs)

    # Rule 1: put idle capacity on the largest backlog.
    if _reserve_officers(obs) > 0:
        busiest = _service_with_max(obs, *load_fields)
        if busiest:
            return ActionModel(
                action_type=ActionType.ASSIGN_CAPACITY,
                service_target=busiest,
                capacity_assignment={busiest.value: 1},
            )

    # Rule 2: unblock cases that are waiting on documents.
    blocked = _service_with_max(obs, "blocked_missing_docs", "missing_docs_cases")
    if blocked:
        return ActionModel(
            action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
            service_target=blocked,
        )

    # Rule 3: shift one officer from the coldest queue to the hottest.
    if len(queues) >= 2:
        # max()/min() keep the first extreme element, matching index 0 of a
        # stable descending/ascending sort.
        hottest = max(queues, key=load_of)
        coldest = min(queues, key=load_of)
        receiver = _service_name(hottest)
        donor = _service_name(coldest)
        hot_load = load_of(hottest)
        cold_load = load_of(coldest)
        distinct_services = receiver and donor and receiver != donor
        if (
            distinct_services
            and hot_load - cold_load >= 3
            and _alloc_for(obs, donor) > 1
        ):
            return ActionModel(
                action_type=ActionType.REALLOCATE_OFFICERS,
                service_target=donor,
                reallocation_delta={donor.value: -1, receiver.value: 1},
            )

    # Rule 4: nothing useful to do this step.
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
def greedy_sla_policy(obs: ObservationModel) -> ActionModel:
    """SLA-focused policy, also reachable through historical alias names.

    Looks for the service with the largest positive value among the fields
    ``urgent_pending``, ``urgent_cases`` and ``breached_cases`` (per
    snapshot, the first of these that is not ``None`` is the one compared).
    If such a service exists, missing documents are requested for it;
    otherwise the decision is delegated to ``backlog_clearance_policy``.
    """
    sla_fields = ("urgent_pending", "urgent_cases", "breached_cases")
    at_risk = _service_with_max(obs, *sla_fields)
    if not at_risk:
        return backlog_clearance_policy(obs)
    return ActionModel(
        action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
        service_target=at_risk,
    )
def random_policy(obs: ObservationModel) -> ActionModel:
    """Return an ADVANCE_TIME action regardless of the observation.

    Despite its name this policy draws no random numbers: it is
    deterministic and always advances time. The previously unused
    function-scope ``import random`` has been removed.

    Args:
        obs: Current observation; not consulted.

    Returns:
        An ``ActionModel`` whose ``action_type`` is ``ActionType.ADVANCE_TIME``.
    """
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
# Rebinds the name introduced by ``def urgent_first_policy`` earlier in this
# module: from here on, ``urgent_first_policy`` refers to greedy_sla_policy.
urgent_first_policy = greedy_sla_policy
# ``fairness_aware_policy`` is an alias for backlog_clearance_policy.
fairness_aware_policy = backlog_clearance_policy
# Registry of policy callables keyed by name; run_policy_episode resolves its
# ``policy_name`` argument through this mapping.
POLICIES: dict[str, PolicyFn] = {
    # "urgent_first" maps to greedy_sla_policy, as does "greedy_sla_policy".
    "urgent_first": greedy_sla_policy,
    "oldest_first": oldest_first_policy,
    "backlog_clearance": backlog_clearance_policy,
    "random_policy": random_policy,
    "greedy_sla_policy": greedy_sla_policy,
    # fairness_aware_policy is an alias of backlog_clearance_policy.
    "fairness_aware_policy": fairness_aware_policy,
}
def run_policy_episode(
    task_id: str,
    policy_name: str,
    seed: int | None = None,
    max_steps: int = 500,
) -> SimpleNamespace:
    """Run one episode of a registered policy and summarise the outcome.

    Args:
        task_id: Task identifier passed to ``GovWorkflowEnv``.
        policy_name: Key into ``POLICIES`` selecting the policy to run.
        seed: Seed forwarded to ``env.reset``; ``None`` is passed through.
        max_steps: Upper bound on the number of ``env.step`` calls.

    Returns:
        A ``SimpleNamespace`` (not a dict, so callers can use attribute
        access such as ``result.score``) with the fields ``task_id``,
        ``policy``, ``seed``, ``reward_sum`` (rounded to 4 decimals),
        ``score``, ``grader``, ``metrics``, ``steps``, ``completed`` and
        ``backlog``.

    Raises:
        KeyError: If ``policy_name`` is not a key of ``POLICIES``.
    """
    # Resolve the policy before building the environment so an unknown name
    # fails immediately, without constructing or resetting an env first.
    policy = POLICIES[policy_name]

    env = GovWorkflowEnv(task_id=task_id)
    obs, _reset_info = env.reset(seed=seed)

    reward_sum = 0.0
    for _step in range(max_steps):
        action = policy(obs)
        obs, reward, terminated, truncated, _step_info = env.step(action)
        reward_sum += reward
        if terminated or truncated:
            break

    state = env.state()
    grade = grade_episode(state)
    return SimpleNamespace(
        task_id=task_id,
        policy=policy_name,
        seed=state.seed,
        reward_sum=round(reward_sum, 4),
        score=float(grade.score),
        grader=grade.grader_name,
        metrics=grade.metrics,
        steps=int(state.total_steps),
        completed=int(state.total_completed),
        backlog=int(state.total_backlog),
    )