"""
env.py — Gov Workflow OpenEnv

Gymnasium/OpenEnv-compatible environment aligned with Phase 1 schemas.
"""

from __future__ import annotations

import random
from uuid import uuid4

from app.event_engine import EventEngine
from app.models import (
    ActionModel,
    ActionType,
    ApplicationCase,
    EpisodeStateModel,
    InternalSubstate,
    ObservationModel,
    OfficerPool,
    PriorityMode,
    QueueSnapshot,
    RewardModel,
    ScenarioMode,
    ServiceType,
    StepInfoModel,
    TaskConfig,
)
from app.reward import compute_reward
from app.signal_computer import SignalComputer
from app.engine import DayResult, DaySimulator
from app.tasks import get_task


def completion_fairness_gap(
    arrived_by_service: dict[ServiceType, int],
    completed_by_service: dict[ServiceType, int],
) -> float:
    """Return the spread between the best and worst per-service completion rates.

    A service's completion rate is ``completed / max(1, arrived)``. Only the
    services that are keys of *arrived_by_service* are considered, so a service
    with zero arrivals contributes a rate of 0.0.

    Args:
        arrived_by_service: Number of cases that arrived, keyed by service.
        completed_by_service: Number of cases completed, keyed by service.

    Returns:
        ``max(rates) - min(rates)``, or 0.0 when fewer than two services are tracked.
    """
    services = list(arrived_by_service.keys())
    if len(services) < 2:
        return 0.0
    rates = []
    for svc in services:
        # max(1, ...) avoids division by zero for services with no arrivals yet.
        arrived = max(1, arrived_by_service.get(svc, 0))
        completed = completed_by_service.get(svc, 0)
        rates.append(completed / arrived)
    # At least two services were checked above, so ``rates`` is never empty here.
    return max(rates) - min(rates)


class EpisodeMetrics:
    """Running counters accumulated over a single episode."""

    def __init__(self):
        self.total_arrived: int = 0
        self.total_completed: int = 0
        self.total_sla_breaches: int = 0
        # NOTE(review): never incremented in this module; only read when reporting.
        self.total_rejected: int = 0
        self.total_invalid_actions: int = 0
        self.total_escalations_used: int = 0
        self.total_wasted_escalations: int = 0
        self.total_docs_requested: int = 0
        self.total_docs_cleared: int = 0
        self.total_idle_officer_days: int = 0
        self.total_capacity_days: int = 0
        self.total_urgent_arrived: int = 0
        self.total_urgent_completed: int = 0
        self.cumulative_reward: float = 0.0

    def to_reward_model(self) -> RewardModel:
        """Wrap the cumulative reward in a ``RewardModel``."""
        return RewardModel(total_reward=self.cumulative_reward)


class GovWorkflowEnv:
    """Simulated government-office workflow environment.

    ``reset()`` returns ``(observation, info)`` and ``step()`` returns
    ``(observation, reward, terminated, truncated, info)``. Only the
    ``ADVANCE_TIME`` action moves the simulation clock; other actions change
    configuration or case state in place.
    """

    def __init__(self, task_id: str = "district_backlog_easy", seed: int | None = None) -> None:
        """Create an environment for *task_id*.

        Args:
            task_id: Identifier passed to ``get_task``.
            seed: Default RNG seed, used whenever ``reset()`` is called without an
                explicit seed. ``None`` falls back to the task's own seed.
        """
        self.task_id = task_id
        self.task: TaskConfig = get_task(task_id)
        # Kept separately because ``self.seed`` is overwritten by every reset.
        self._default_seed: int | None = None if seed is None else int(seed)
        self.max_steps_per_episode = max(1, int(self.task.max_days) * 10)
        self._init_episode_state()

    def reset(
        self,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[ObservationModel, dict]:
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: RNG seed for this episode. When ``None``, the constructor seed is
                used, or the task's own seed if no constructor seed was given.
            options: Optional overrides: ``"task_id"`` switches the task and
                ``"max_steps_per_episode"`` overrides the step cap
                (default ``max_days * 10``).
        """
        options = options or {}
        task_id = options.get("task_id", self.task_id)
        self.task = get_task(task_id)
        self.task_id = self.task.task_id
        if seed is not None:
            self.seed = int(seed)
        elif self._default_seed is not None:
            self.seed = self._default_seed
        else:
            self.seed = self.task.seed
        self.rng = random.Random(self.seed)

        max_steps_override = options.get("max_steps_per_episode")
        if max_steps_override is None:
            self.max_steps_per_episode = max(1, int(self.task.max_days) * 10)
        else:
            self.max_steps_per_episode = max(1, int(max_steps_override))

        self.episode_id = f"{self.task_id}-s{self.seed}-{uuid4().hex[:6]}"
        self.day = 0
        self.total_steps = 0
        self.terminated = False
        self.truncated = False
        self.priority_mode = PriorityMode.BALANCED

        # Copy the task's pool: _apply_action mutates ``allocated`` in place and
        # that must not leak back into the shared TaskConfig.
        pool = self.task.initial_officer_pool
        self.officer_pool = OfficerPool(
            total_officers=pool.total_officers,
            available_officers=pool.available_officers,
            allocated=dict(pool.allocated),
            pending_reallocation=dict(getattr(pool, "pending_reallocation", {})),
        )

        self.active_cases: list[ApplicationCase] = []
        self.completed_cases: list[ApplicationCase] = []
        self.escalation_budget_remaining = self.task.escalation_budget
        self.arrived_by_service = {s: 0 for s in self.task.enabled_services}
        self.completed_by_service = {s: 0 for s in self.task.enabled_services}
        # case_ids already folded into the per-service arrival / completion tallies.
        self._arrived_case_ids: set = set()
        self._completed_case_ids: set = set()
        # Completions discovered on the most recently simulated day, per service.
        self._completed_today_by_service: dict[ServiceType, int] = {}
        self.metrics = EpisodeMetrics()
        self.action_history: list[dict] = []
        self.last_action_valid = True
        self.last_action_message = "reset"
        self.last_action_explanation = ""

        self.event_engine = EventEngine(
            seed=self.seed,
            scenario_mode=self.task.scenario_mode,
        )
        self.simulator = DaySimulator(
            task_config=self.task,
            rng=self.rng,
            event_engine=self.event_engine,
        )
        self.signal_computer = SignalComputer()

        obs = self._build_observation(active_events=[])
        info = {
            "task_id": self.task_id,
            "seed": self.seed,
            "episode_id": self.episode_id,
            "max_days": self.task.max_days,
        }
        return obs, info

    def step(
        self,
        action: ActionModel | dict,
    ) -> tuple[ObservationModel, float, bool, bool, StepInfoModel]:
        """Apply one agent action and return the Gymnasium-style 5-tuple.

        Actions rejected with ``ValueError`` are recorded as invalid (and scored by
        ``compute_reward`` with ``invalid_action=True``) instead of propagating.
        Errors raised while building an ``ActionModel`` from a dict are not caught.

        Args:
            action: An ``ActionModel`` or a dict of its fields.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode has ended and ``reset()`` was not called.
        """
        if isinstance(action, dict):
            action = ActionModel(**action)
        if self.terminated or self.truncated:
            raise RuntimeError("Episode ended — call reset() before stepping.")

        self.total_steps += 1
        invalid_action = False
        # Non-time actions are scored against a default-constructed DayResult;
        # only ADVANCE_TIME replaces it with the simulated day's result.
        day_result = DayResult()
        try:
            notes, day_result = self._apply_action(action, day_result)
            self.last_action_valid = True
            self.last_action_message = notes[-1] if notes else "ok"
            self.last_action_explanation = self.last_action_message
        except ValueError as exc:
            invalid_action = True
            self.metrics.total_invalid_actions += 1
            self.last_action_valid = False
            self.last_action_message = str(exc)
            self.last_action_explanation = f"Invalid: {exc}"

        fairness_gap = completion_fairness_gap(
            self.arrived_by_service,
            self.completed_by_service,
        )
        reward: RewardModel = compute_reward(
            stage_advances=day_result.stage_advances,
            completions=day_result.new_completions,
            active_backlog=len(self.active_cases),
            new_sla_breaches=day_result.new_sla_breaches,
            fairness_gap=fairness_gap,
            fairness_threshold=self.task.fairness_threshold or 0.0,
            invalid_action=invalid_action,
            idle_capacity=day_result.idle_officer_days,
            award_stability_bonus=(action.action_type == ActionType.ADVANCE_TIME),
        )
        self.metrics.cumulative_reward += reward.total_reward

        # Natural end: backlog drained after at least one simulated day.
        self.terminated = (
            len(self.active_cases) == 0 and self.day > 0 and not invalid_action
        )
        # Time / step limits only truncate when the episode did not end naturally.
        self.truncated = (
            (self.day >= self.task.max_days or self.total_steps >= self.max_steps_per_episode)
            and not self.terminated
        )

        info = StepInfoModel(
            reward_breakdown=reward,
            newly_arrived_cases=day_result.new_arrivals,
            newly_completed_cases=day_result.new_completions,
            newly_sla_breached_cases=day_result.new_sla_breaches,
            newly_resolved_doc_cases=day_result.newly_unblocked_missing,
            invalid_action=invalid_action,
            action_explanation=self.last_action_explanation,
            active_events=day_result.active_events,
            grader_preview_score=0.0,
            effects_resolved_this_step=[],
        )
        self.action_history.append({
            "step": self.total_steps,
            "day": self.day,
            "action": action.model_dump(mode="json"),
            "invalid": invalid_action,
            "message": self.last_action_message,
            "reward": reward.total_reward,
        })
        obs = self._build_observation(active_events=day_result.active_events)
        return obs, reward.total_reward, self.terminated, self.truncated, info

    def count_pending_effects(self) -> int:
        """Count all pending delayed effects waiting to resolve.

        Probes, in order, ``self._pending_effects``, ``self.simulator.pending_effects``
        and ``self.pending_effects``, returning the length of the first one found.
        NOTE(review): none of these attributes is assigned in this module, so unless
        the simulator exposes ``pending_effects`` this returns 0; the unused
        ``_count_pending_effects`` helper below may be what was intended — confirm.
        """
        if hasattr(self, '_pending_effects') and self._pending_effects:
            return len(self._pending_effects)
        if hasattr(self, 'simulator') and hasattr(self.simulator, 'pending_effects'):
            return len(self.simulator.pending_effects)
        if hasattr(self, 'pending_effects'):
            return len(self.pending_effects)
        return 0

    def state(self) -> EpisodeStateModel:
        """Return a snapshot of episode-level state and grader-facing metrics."""
        fairness_gap = completion_fairness_gap(
            self.arrived_by_service, self.completed_by_service
        )
        # Average waiting days across completed cases (0.0 when none completed).
        avg_wait = (
            sum(c.waiting_days for c in self.completed_cases) / len(self.completed_cases)
            if self.completed_cases
            else 0.0
        )
        return EpisodeStateModel(
            episode_id=self.episode_id,
            task_id=self.task_id,
            seed=self.seed,
            scenario_mode=self.task.scenario_mode,
            day=self.day,
            max_days=self.task.max_days,
            terminated=self.terminated,
            truncated=self.truncated,
            total_steps=self.total_steps,
            total_completed=len(self.completed_cases),
            total_backlog=len(self.active_cases),
            total_sla_breaches=self.metrics.total_sla_breaches,
            total_rejected=self.metrics.total_rejected,
            action_history_count=len(self.action_history),
            cumulative_reward=self.metrics.cumulative_reward,
            officer_pool=self.officer_pool.model_copy(deep=True),
            pending_effects_count=self.count_pending_effects(),
            active_events_today=[],
            # ── Grader-facing fields ──────────────────────────────────
            fairness_gap=round(fairness_gap, 4),
            total_arrived=self.metrics.total_arrived,
            total_docs_requested=self.metrics.total_docs_requested,
            total_docs_cleared=self.metrics.total_docs_cleared,
            total_idle_officer_days=self.metrics.total_idle_officer_days,
            total_capacity_days=self.metrics.total_capacity_days,
            total_urgent_arrived=self.metrics.total_urgent_arrived,
            total_urgent_completed=self.metrics.total_urgent_completed,
            total_escalations_used=self.metrics.total_escalations_used,
            total_wasted_escalations=self.metrics.total_wasted_escalations,
            total_invalid_actions=self.metrics.total_invalid_actions,
            avg_waiting_days=round(avg_wait, 2),
            # Full action log — populated but stripped by API unless requested
            action_history=list(self.action_history),
        )

    @staticmethod
    def _coerce_service(svc_key: ServiceType | str) -> ServiceType:
        """Convert a raw dict key (enum member or its string value) to ``ServiceType``.

        Raises:
            ValueError: From the enum lookup when a string is not a valid value.
        """
        return ServiceType(svc_key) if isinstance(svc_key, str) else svc_key

    def _apply_action(
        self,
        action: ActionModel,
        day_result: DayResult,
    ) -> tuple[list[str], DayResult]:
        """Dispatch *action* by type and return ``(notes, day_result)``.

        ``notes`` holds human-readable messages (the last one becomes the step
        message). ``day_result`` is returned unchanged except for ``ADVANCE_TIME``,
        which replaces it with the simulated day's result.

        Raises:
            ValueError: For any invalid or unsupported action. Validation happens
                before state is mutated, except that a failed escalation still
                counts as a wasted escalation.
        """
        notes: list[str] = []

        if action.action_type == ActionType.SET_PRIORITY_MODE:
            if action.priority_mode is None:
                raise ValueError("priority_mode required for set_priority_mode")
            old_mode = self.priority_mode
            self.priority_mode = action.priority_mode
            notes.append(f"Priority mode changed: {old_mode.value} -> {action.priority_mode.value}")
            return notes, day_result

        if action.action_type == ActionType.ASSIGN_CAPACITY:
            cap = action.capacity_assignment
            if not cap:
                raise ValueError("capacity_assignment dict required for assign_capacity")
            # Validate every entry before touching the pool so a bad entry cannot
            # leave earlier entries half-applied. Assumes ``idle_officers`` reflects
            # the current allocations, so earlier entries in this same action reduce
            # what remains for later ones.
            idle = self.officer_pool.idle_officers
            planned: list[tuple[ServiceType, int]] = []
            for svc_key, delta in cap.items():
                svc = self._coerce_service(svc_key)
                if svc not in self.task.enabled_services:
                    raise ValueError(f"{svc.value} is not enabled in this task")
                if delta <= 0:
                    raise ValueError("capacity delta must be positive")
                if delta > idle:
                    raise ValueError(f"Only {idle} idle officers available; requested {delta}")
                idle -= delta
                planned.append((svc, delta))
            for svc, delta in planned:
                self.officer_pool.allocated[svc] = self.officer_pool.allocated.get(svc, 0) + delta
                notes.append(f"Assigned {delta} officer(s) to {svc.value}")
            return notes, day_result

        if action.action_type == ActionType.REQUEST_MISSING_DOCUMENTS:
            svc = action.service_target
            if svc is None:
                raise ValueError("service_target required for request_missing_documents")
            blocked = [
                c for c in self.active_cases
                if c.service_type == svc
                and c.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
            ]
            if not blocked:
                raise ValueError(f"No BLOCKED_MISSING_DOCS cases for {svc.value}")
            # Skip cases whose earlier request is still in flight (resolution day in
            # the future): re-sending would push their resolution back and
            # double-count total_docs_requested. Overdue requests may be re-sent.
            candidates = [
                c for c in blocked
                if c.doc_resolution_day is None or c.doc_resolution_day <= self.day
            ]
            if not candidates:
                raise ValueError(
                    f"All BLOCKED_MISSING_DOCS cases for {svc.value} already have pending requests"
                )
            # Highest SLA risk first; ties broken by earliest arrival.
            candidates.sort(key=lambda c: (-c.sla_risk, c.arrival_day))
            batch = candidates[:3]
            for case in batch:
                case.doc_request_sent_day = self.day
                case.doc_resolution_day = self.day + self.rng.randint(2, 3)
                self.metrics.total_docs_requested += 1
            notes.append(f"Sent missing-doc requests for {len(batch)} case(s) in {svc.value}")
            return notes, day_result

        if action.action_type == ActionType.ESCALATE_SERVICE:
            if self.escalation_budget_remaining <= 0:
                self.metrics.total_wasted_escalations += 1
                raise ValueError("Escalation budget exhausted")
            svc = action.escalation_target or action.service_target
            # No target service means any non-urgent active case is eligible.
            candidates = [
                c for c in self.active_cases
                if (svc is None or c.service_type == svc) and not c.is_urgent
            ]
            if not candidates:
                self.metrics.total_wasted_escalations += 1
                raise ValueError("No eligible non-urgent cases to escalate")
            # Highest SLA risk; ties go to the earliest arrival.
            best = max(candidates, key=lambda c: (c.sla_risk, -c.arrival_day))
            best.is_urgent = True
            self.escalation_budget_remaining -= 1
            self.metrics.total_escalations_used += 1
            notes.append(f"Escalated case {best.case_id} ({best.service_type.value})")
            return notes, day_result

        if action.action_type == ActionType.ADVANCE_TIME:
            day_result = self._advance_one_day()
            notes.append(f"Day {self.day} simulated")
            return notes, day_result

        if action.action_type == ActionType.REALLOCATE_OFFICERS:
            delta = action.reallocation_delta
            if not delta or len(delta) < 2:
                raise ValueError("reallocation_delta must have at least 2 entries")
            total = sum(delta.values())
            if total != 0:
                raise ValueError(f"reallocation_delta must sum to 0 (got {total})")
            # Validate all entries first, then apply, so the move is all-or-nothing.
            planned_moves: list[tuple[ServiceType, int]] = []
            for svc_key, change in delta.items():
                svc = self._coerce_service(svc_key)
                if svc not in self.task.enabled_services:
                    raise ValueError(f"{svc.value} not in enabled services")
                current = self.officer_pool.allocated.get(svc, 0)
                if current + change < 0:
                    raise ValueError(
                        f"Cannot reduce {svc.value} below 0 (current={current}, change={change})"
                    )
                planned_moves.append((svc, change))
            for svc, change in planned_moves:
                self.officer_pool.allocated[svc] = self.officer_pool.allocated.get(svc, 0) + change
            changes = ", ".join(
                f"{svc.value}:{'+' if change > 0 else ''}{change}" for svc, change in planned_moves
            )
            notes.append(f"Officers reallocated: {changes}")
            return notes, day_result

        raise ValueError(f"Unsupported action_type: {action.action_type.value}")

    def _advance_one_day(self) -> DayResult:
        """Simulate the next day and fold its outcome into the episode tallies."""
        self.day += 1
        alloc = dict(self.officer_pool.allocated)
        result = self.simulator.simulate_day(
            day=self.day,
            active_cases=self.active_cases,
            completed_cases=self.completed_cases,
            priority_mode=self.priority_mode,
            officer_allocations=alloc,
        )

        # Completions not seen before are exactly the ones produced today.
        completed_today: dict[ServiceType, int] = {}
        for case in self.completed_cases:
            if case.case_id in self._completed_case_ids:
                continue
            self._completed_case_ids.add(case.case_id)
            svc = case.service_type
            self.completed_by_service[svc] = self.completed_by_service.get(svc, 0) + 1
            completed_today[svc] = completed_today.get(svc, 0) + 1
        self._completed_today_by_service = completed_today

        # Scan completed cases as well: a case that arrives and finishes within the
        # same simulated day never appears in active_cases, yet it did arrive.
        for cases in (self.active_cases, self.completed_cases):
            for case in cases:
                if case.case_id in self._arrived_case_ids:
                    continue
                self._arrived_case_ids.add(case.case_id)
                svc = case.service_type
                self.arrived_by_service[svc] = self.arrived_by_service.get(svc, 0) + 1
                self.metrics.total_arrived += 1
                if case.is_urgent:
                    self.metrics.total_urgent_arrived += 1

        self.metrics.total_completed = len(self.completed_cases)
        self.metrics.total_sla_breaches += result.new_sla_breaches
        self.metrics.total_idle_officer_days += result.idle_officer_days
        self.metrics.total_capacity_days += result.total_capacity_days
        self.metrics.total_urgent_completed += result.urgent_completed
        self.metrics.total_docs_cleared += result.newly_unblocked_missing
        return result

    def _build_observation(self, active_events: list | None = None) -> ObservationModel:
        """Assemble the agent-facing observation for the current day."""
        active_events = active_events or []

        snapshots: dict[str, QueueSnapshot] = {}
        for service in self.task.enabled_services:
            snap = self.simulator.build_queue_snapshot(service, self.active_cases, self.day)
            # Only the latest simulated day's completions, not the episode total.
            snap.total_completed_today = self._completed_today_by_service.get(service, 0)
            snapshots[service.value] = snap

        # Today's arrivals include cases that were already completed the same day.
        todays_arrivals = 0
        todays_digital = 0
        for cases in (self.active_cases, self.completed_cases):
            for case in cases:
                if case.arrival_day != self.day:
                    continue
                todays_arrivals += 1
                if case.intake_channel.value == "digital":
                    todays_digital += 1

        sigs = self.signal_computer.compute(
            queue_snapshots=snapshots,
            officer_pool=self.officer_pool,
            todays_arrivals=todays_arrivals,
            digital_arrivals=todays_digital,
            capacity_per_day=max(1.0, float(self.officer_pool.available_officers)),
        )

        # Blocked cases that already have a document request scheduled.
        pending_doc = sum(
            1 for c in self.active_cases
            if c.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
            and c.doc_resolution_day is not None
        )
        pending_officer = len(getattr(self.officer_pool, "pending_reallocation", {}))

        return ObservationModel(
            task_id=self.task_id,
            episode_id=self.episode_id,
            day=self.day,
            max_days=self.task.max_days,
            scenario_mode=self.task.scenario_mode,
            officer_pool=self.officer_pool.model_copy(deep=True),
            queue_snapshots=snapshots,
            total_backlog=len(self.active_cases),
            total_completed=len(self.completed_cases),
            total_sla_breaches=self.metrics.total_sla_breaches,
            total_rejected=self.metrics.total_rejected,
            escalation_budget_remaining=self.escalation_budget_remaining,
            backlog_pressure=sigs.backlog_pressure,
            sla_risk_score=sigs.sla_risk_score,
            fairness_index=sigs.fairness_index,
            resource_utilization=sigs.resource_utilization,
            digital_intake_ratio=sigs.digital_intake_ratio,
            blocked_cases_missing_docs=sigs.blocked_cases_missing_docs,
            field_verification_load=sigs.field_verification_load,
            active_events=active_events,
            last_action_valid=self.last_action_valid,
            last_action_message=self.last_action_message,
            last_action_explanation=self.last_action_explanation,
            pending_doc_resolutions=pending_doc,
            pending_officer_reallocations=pending_officer,
        )

    def _init_episode_state(self) -> None:
        """Install placeholder state so attributes exist before the first ``reset()``."""
        self.seed = self.task.seed if self._default_seed is None else self._default_seed
        self.rng = random.Random(self.seed)
        self.episode_id = f"{self.task_id}-s{self.seed}-init"
        self.day = 0
        self.total_steps = 0
        self.terminated = False
        self.truncated = False
        self.priority_mode = PriorityMode.BALANCED
        self.officer_pool = OfficerPool(
            total_officers=1,
            available_officers=1,
            allocated={},
            pending_reallocation={},
        )
        self.active_cases: list[ApplicationCase] = []
        self.completed_cases: list[ApplicationCase] = []
        self.escalation_budget_remaining = 0
        self.arrived_by_service: dict[ServiceType, int] = {}
        self.completed_by_service: dict[ServiceType, int] = {}
        self._arrived_case_ids: set = set()
        self._completed_case_ids: set = set()
        self._completed_today_by_service: dict[ServiceType, int] = {}
        self.metrics = EpisodeMetrics()
        self.action_history: list[dict] = []
        self.last_action_valid = True
        self.last_action_message = ""
        self.last_action_explanation = ""
        self.event_engine = EventEngine(seed=self.seed, scenario_mode=ScenarioMode.NORMAL)
        self.simulator = DaySimulator(self.task, self.rng, self.event_engine)
        self.signal_computer = SignalComputer()

    def _count_pending_effects(self) -> int:
        """Count scheduled doc resolutions plus scheduled field verifications.

        NOTE(review): not referenced elsewhere in this module.
        """
        doc_pending = sum(
            1 for c in self.active_cases
            if c.doc_resolution_day is not None
            and c.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
        )
        fv_pending = sum(
            1 for c in self.active_cases
            if c.internal_substate == InternalSubstate.FIELD_VERIFICATION_PENDING
            and c.field_verification_completion_day is not None
        )
        return doc_pending + fv_pending

    @property
    def fairness_gap(self) -> float:
        """Current completion-rate spread across services."""
        return completion_fairness_gap(self.arrived_by_service, self.completed_by_service)

    @property
    def total_completed(self) -> int:
        """Number of cases completed so far this episode."""
        return len(self.completed_cases)

    @property
    def total_backlog(self) -> int:
        """Number of cases still active."""
        return len(self.active_cases)