Gov_Workflow_RL / app /baselines.py
Siddharaj Shirke
deploy: clean code-only snapshot for HF Space
df97e68
from __future__ import annotations
from collections.abc import Callable
from types import SimpleNamespace
from app.env import GovWorkflowEnv
from app.graders import grade_episode
from app.models import ActionModel, ActionType, ObservationModel, PriorityMode, ServiceType
# Signature shared by every baseline policy in this module: a policy maps the
# current observation to the single action to take on the next step.
PolicyFn = Callable[[ObservationModel], ActionModel]
def _snapshots(obs: ObservationModel):
"""Return queue snapshots as a list regardless of Phase 1 (list) or Phase 2 (dict)."""
qs = obs.queue_snapshots
if isinstance(qs, dict):
return list(qs.values())
return list(qs)
def _service_attr(q, *attrs):
"""Return the first attribute that exists on a QueueSnapshot (Phase 1 vs Phase 2 names)."""
for attr in attrs:
val = getattr(q, attr, None)
if val is not None:
return val
return 0
def _service_name(q) -> ServiceType:
"""Return ServiceType regardless of Phase 1 (.service) or Phase 2 (.service_type)."""
return getattr(q, "service_type", None) or getattr(q, "service", None)
def _service_with_max(obs: ObservationModel, *attrs) -> ServiceType | None:
    """Return the service whose snapshot has the largest ``attrs`` value.

    Each snapshot is scored with ``_service_attr(snapshot, *attrs)``. Ties go
    to the snapshot that appears first. ``None`` is returned when there are no
    snapshots or when the best score is not strictly positive.
    """
    candidates = _snapshots(obs)
    if not candidates:
        return None

    def score(snapshot):
        return _service_attr(snapshot, *attrs)

    # max() keeps the first of several equal maxima, matching a stable
    # descending sort followed by taking element 0.
    leader = max(candidates, key=score)
    if score(leader) > 0:
        return _service_name(leader)
    return None
def _reserve_officers(obs: ObservationModel) -> int:
pool = obs.officer_pool
# Phase 2: idle_officers property
if hasattr(pool, "idle_officers"):
return int(pool.idle_officers)
# Phase 1 fallback
return int(getattr(pool, "reserve_officers", 0))
def _alloc_for(obs: ObservationModel, service: ServiceType) -> int:
pool = obs.officer_pool
# Phase 2 uses 'allocated'; Phase 1 used 'allocations'
alloc_dict = getattr(pool, "allocated", None) or getattr(pool, "allocations", {})
raw = alloc_dict.get(service)
if raw is None:
raw = alloc_dict.get(service.value if hasattr(service, "value") else str(service), 0)
return int(raw or 0)
def urgent_first_policy(obs: ObservationModel) -> ActionModel:
    """Request missing documents for the service with the most urgent cases.

    Falls back to advancing time when no service reports a positive urgent
    count.

    NOTE: the module later rebinds the name ``urgent_first_policy`` to
    ``greedy_sla_policy``, so after import this definition is no longer
    reachable under this name.
    """
    most_urgent = _service_with_max(obs, "urgent_pending", "urgent_cases")
    if not most_urgent:
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    return ActionModel(
        action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
        service_target=most_urgent,
    )
def oldest_first_policy(obs: ObservationModel) -> ActionModel:
    """Return an ``ADVANCE_TIME`` action regardless of the observation.

    The observation is accepted only to satisfy the policy signature; it is
    not inspected.
    """
    advance = ActionModel(action_type=ActionType.ADVANCE_TIME)
    return advance
def backlog_clearance_policy(obs: ObservationModel) -> ActionModel:
    """Pick an action aimed at reducing the largest backlog.

    Decisions are tried in this order, and the first that applies is returned:

    1. If idle officers exist, assign one to the service with the most
       pending cases.
    2. Request missing documents for the service with the most cases blocked
       on missing documents.
    3. With at least two queues, move one officer from the least-loaded to the
       most-loaded service when their load differs by 3 or more and the
       least-loaded service has more than one officer allocated.
    4. Otherwise advance time.
    """
    load_fields = ("total_pending", "active_cases")
    queues = _snapshots(obs)

    # 1. Put idle capacity on the most backlogged service.
    if _reserve_officers(obs) > 0:
        busiest = _service_with_max(obs, *load_fields)
        if busiest:
            return ActionModel(
                action_type=ActionType.ASSIGN_CAPACITY,
                service_target=busiest,
                capacity_assignment={busiest.value: 1},
            )

    # 2. Unblock cases waiting on missing documents.
    blocked = _service_with_max(obs, "blocked_missing_docs", "missing_docs_cases")
    if blocked:
        return ActionModel(
            action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
            service_target=blocked,
        )

    # 3. Shift one officer from the coldest queue to the hottest one.
    if len(queues) >= 2:

        def load(snapshot):
            return _service_attr(snapshot, *load_fields)

        # max()/min() keep the first of equal extremes, matching element 0 of
        # a stable descending / ascending sort.
        hottest = max(queues, key=load)
        coldest = min(queues, key=load)
        hot_svc = _service_name(hottest)
        cold_svc = _service_name(coldest)
        hot_load = load(hottest)
        cold_load = load(coldest)

        services_differ = hot_svc and cold_svc and hot_svc != cold_svc
        if (
            services_differ
            and hot_load - cold_load >= 3
            and _alloc_for(obs, cold_svc) > 1
        ):
            return ActionModel(
                action_type=ActionType.REALLOCATE_OFFICERS,
                service_target=cold_svc,
                reallocation_delta={cold_svc.value: -1, hot_svc.value: 1},
            )

    # 4. Nothing useful to do this step.
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
def greedy_sla_policy(obs: ObservationModel) -> ActionModel:
    """SLA-focused policy, also exposed under historical alias names.

    Requests missing documents for the service ranked highest by the first
    available of ``urgent_pending``, ``urgent_cases`` or ``breached_cases``.
    When no service has a positive value, the decision is delegated to
    ``backlog_clearance_policy``.
    """
    sla_target = _service_with_max(obs, "urgent_pending", "urgent_cases", "breached_cases")
    if not sla_target:
        return backlog_clearance_policy(obs)
    return ActionModel(
        action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
        service_target=sla_target,
    )
def random_policy(obs: ObservationModel) -> ActionModel:
    """Return an ``ADVANCE_TIME`` action regardless of the observation.

    Despite its name, this policy draws no random numbers: it is
    deterministic and never inspects ``obs``. The previously unused
    function-local ``import random`` has been removed.
    """
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
# Historical aliases. The first assignment rebinds ``urgent_first_policy``:
# from this point on the module-level name refers to ``greedy_sla_policy``
# and replaces the function defined earlier under the same name.
urgent_first_policy = greedy_sla_policy
fairness_aware_policy = backlog_clearance_policy
# Registry of baseline policies keyed by the name passed to
# ``run_policy_episode``. Several keys share one implementation:
# "urgent_first" and "greedy_sla_policy" both map to ``greedy_sla_policy``,
# and "fairness_aware_policy" maps to ``backlog_clearance_policy``.
POLICIES: dict[str, PolicyFn] = {
    "urgent_first": greedy_sla_policy,
    "oldest_first": oldest_first_policy,
    "backlog_clearance": backlog_clearance_policy,
    "random_policy": random_policy,
    "greedy_sla_policy": greedy_sla_policy,
    "fairness_aware_policy": fairness_aware_policy,
}
def run_policy_episode(
    task_id: str,
    policy_name: str,
    seed: int | None = None,
    max_steps: int = 500,
) -> SimpleNamespace:
    """Run one episode of a registered baseline policy and grade the result.

    Args:
        task_id: Task identifier passed to ``GovWorkflowEnv``.
        policy_name: Key into ``POLICIES`` selecting the policy to run.
        seed: Seed forwarded to ``env.reset``.
        max_steps: Upper bound on the number of environment steps.

    Returns:
        A ``SimpleNamespace`` (not a dict, so callers can use attribute
        access such as ``result.score``) with the fields ``task_id``,
        ``policy``, ``seed``, ``reward_sum``, ``score``, ``grader``,
        ``metrics``, ``steps``, ``completed`` and ``backlog``.

    Raises:
        KeyError: If ``policy_name`` is not a key of ``POLICIES``.
    """
    env = GovWorkflowEnv(task_id=task_id)
    obs, _info = env.reset(seed=seed)
    policy = POLICIES[policy_name]

    reward_sum = 0.0
    for _step in range(max_steps):
        action = policy(obs)
        obs, reward, terminated, truncated, _info = env.step(action)
        reward_sum += reward
        if terminated or truncated:
            break

    state = env.state()
    grade = grade_episode(state)
    # Return a SimpleNamespace so attribute access (result.score) works in main.py
    return SimpleNamespace(
        task_id=task_id,
        policy=policy_name,
        # The seed is reported from the final state rather than the argument.
        seed=state.seed,
        reward_sum=round(reward_sum, 4),
        score=float(grade.score),
        grader=grade.grader_name,
        metrics=grade.metrics,
        steps=int(state.total_steps),
        completed=int(state.total_completed),
        backlog=int(state.total_backlog),
    )