Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| import re | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any, Literal, Optional | |
| from openai import OpenAI | |
| from app.event_engine import EventEngine | |
| from app.models import ( | |
| ActionModel, | |
| ActionType, | |
| ApplicationCase, | |
| DelayedEffect, | |
| EventType, | |
| IntakeChannel, | |
| InternalSubstate, | |
| ObservationModel, | |
| PriorityMode, | |
| QueueSnapshot, | |
| ServiceType, | |
| StageType, | |
| ) | |
| from app.sector_profiles import get_sector_profile | |
| from app.state_machine import can_advance | |
| if TYPE_CHECKING: | |
| from app.models import TaskConfig | |
# Hosted model identifiers kept for backward compatibility.
# NOTE(review): nothing in this file reads this list — confirm external
# callers before removing. Order may be meaningful to them.
LEGACY_NVIDIA_MODEL_POOL = [
    "meta/llama-3.3-70b-instruct",
    "qwen/qwen3-next-80b-a3b-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "meta/llama-3.1-405b-instruct",
    "deepseek-ai/deepseek-v3.2",
    "qwen/qwq-32b",
    "mistralai/mixtral-8x22b-instruct-v0.1",
    "google/gemma-3-27b-it",
    "microsoft/phi-4-mini-instruct",
    "meta/llama-3.1-8b-instruct",
]
# Process-wide cache of loaded RL checkpoints, keyed by (absolute model path,
# model type); populated by _load_model_cached_or_raise.
_MODEL_CACHE: dict[tuple[str, str], Any] = {}
| # ───────────────────────────────────────────── | |
| # DAY RESULT | |
| # ───────────────────────────────────────────── | |
class DayResult:
    """Per-day counters accumulated by DaySimulator.simulate_day.

    A fresh instance is created for each simulated day and its fields are
    mutated in place as the day's phases run.
    """

    def __init__(self) -> None:
        self.new_arrivals: int = 0             # cases spawned today
        self.new_completions: int = 0          # cases that reached a final stage today
        self.new_sla_breaches: int = 0         # cases that first crossed their SLA deadline today
        self.total_capacity_days: int = 0      # officer-days available across all services
        self.idle_officer_days: int = 0        # capacity for services with no pending cases
        self.stage_advances: int = 0           # successful case stage transitions
        self.newly_unblocked_missing: int = 0  # cases released from the missing-docs block
        self.newly_blocked_missing: int = 0    # NOTE(review): never written in this file — confirm still used
        self.newly_unblocked_enrich: int = 0   # NOTE(review): never written in this file — confirm still used
        self.field_verif_completed: int = 0    # field verifications resolved today
        self.urgent_completed: int = 0         # completed cases that were flagged urgent
        self.digital_arrivals: int = 0         # arrivals via the digital intake channel
        self.active_events: list[EventType] = []  # events in effect today
| # ───────────────────────────────────────────── | |
| # DAY SIMULATOR | |
| # ───────────────────────────────────────────── | |
class DaySimulator:
    """
    Core daily simulation engine.

    Accepts TWO calling conventions so both env.py and tests work:

    Convention A (tests):
        DaySimulator(task_config=task, rng=rng, event_engine=engine)

    Convention B (env.py legacy):
        DaySimulator(seed=42, task_config=task, sector_registry={})
        — in this case rng and event_engine are built internally.
    """

    def __init__(
        self,
        task_config: "TaskConfig",
        rng: Optional[random.Random] = None,
        event_engine: Optional[EventEngine] = None,
        seed: Optional[int] = None,
        sector_registry: Optional[dict] = None,
    ) -> None:
        self.task_config = task_config
        # Legacy alias: some callers still reference `.task`.
        self.task = task_config
        # RNG precedence: explicit rng > explicit seed > the task's own seed.
        if rng is not None:
            self.rng = rng
        elif seed is not None:
            self.rng = random.Random(seed)
        else:
            self.rng = random.Random(task_config.seed)
        if event_engine is not None:
            self.event_engine = event_engine
        else:
            _seed = seed if seed is not None else task_config.seed
            self.event_engine = EventEngine(
                seed=_seed,
                scenario_mode=task_config.scenario_mode,
            )
        self.sector_registry = sector_registry or {}
        self.active_cases: list[ApplicationCase] = []
        self.pending_effects: list[DelayedEffect] = []
        self.case_counter: int = 0

    def simulate_day(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        completed_cases: list[ApplicationCase],
        priority_mode: PriorityMode,
        officer_allocations: dict,
    ) -> DayResult:
        """Run one simulated day, mutating *active_cases*/*completed_cases* in place.

        Phase order: events -> arrivals -> officer reduction -> field-verification
        and doc-request resolution -> per-service case processing -> completion
        bookkeeping -> aging and SLA-breach detection.

        Returns a DayResult holding the day's counters.
        """
        # Kept at function level (presumably to avoid an import cycle with
        # app.state_machine — confirm), but hoisted out of the per-case loop:
        # the original re-executed this import once per case per service.
        from app.state_machine import advance_case

        result = DayResult()
        events = self.event_engine.get_events_for_day(day, self.task_config)
        params = self.event_engine.apply_events(events, self.task_config)
        result.active_events = list(params.active_events)
        new_cases = self._spawn_arrivals(day, params, result)
        active_cases.extend(new_cases)
        effective_alloc = self._apply_officer_reduction(officer_allocations, params)
        self._resolve_field_verification(day, active_cases, result)
        self._resolve_doc_requests(day, active_cases, result)
        newly_completed: list[ApplicationCase] = []
        for service in self.task_config.enabled_services:
            # Allocations may be keyed by the enum or by its string value.
            capacity = effective_alloc.get(service, effective_alloc.get(service.value, 0))
            result.total_capacity_days += int(capacity)
            service_cases = [
                c
                for c in active_cases
                if c.service_type == service and not c.completed and not c.rejected
            ]
            if not service_cases:
                # No pending work for this service: its whole capacity is idle.
                result.idle_officer_days += int(capacity)
                continue
            sorted_cases = self._sort_queue(service_cases, priority_mode)
            for case in sorted_cases:
                if capacity <= 0:
                    break
                advanced, final = advance_case(case, day)
                if advanced:
                    capacity -= 1
                    result.stage_advances += 1
                    if final:
                        newly_completed.append(case)
                        if case.is_urgent:
                            result.urgent_completed += 1
        if newly_completed:
            # Rebuild active_cases in place so callers holding the same list
            # object observe the removal of completed cases.
            done_ids = {c.case_id for c in newly_completed}
            still_active = [c for c in active_cases if c.case_id not in done_ids]
            active_cases.clear()
            active_cases.extend(still_active)
            completed_cases.extend(newly_completed)
            result.new_completions = len(newly_completed)
        for case in active_cases:
            case.current_day = day
            case.waiting_days += 1
            # A breach is counted once: the first day past the deadline.
            if day > case.sla_deadline_day and not case.sla_breached:
                case.sla_breached = True
                result.new_sla_breaches += 1
        return result

    def _apply_officer_reduction(self, allocations: dict, params: Any) -> dict:
        """Return a copy of *allocations* with `params.officer_reduction`
        officers removed, each taken from the currently largest allocation."""
        reduction = int(getattr(params, "officer_reduction", 0))
        if reduction <= 0:
            return dict(allocations)
        effective = dict(allocations)
        for _ in range(reduction):
            target = max(effective, key=lambda k: effective[k], default=None)
            if target is None or effective[target] <= 0:
                break  # nothing left to take
            effective[target] -= 1
        return effective

    def _spawn_arrivals(
        self,
        day: int,
        params: Any,
        result: DayResult,
    ) -> list[ApplicationCase]:
        """Create today's arriving cases for every enabled service.

        The fractional part of (base rate x event arrival multiplier) is
        resolved stochastically so the long-run average matches the rate.
        Updates result.new_arrivals and result.digital_arrivals.
        """
        new_cases: list[ApplicationCase] = []
        for service in self.task_config.enabled_services:
            base_rate = self.task_config.arrival_rate_per_day.get(
                service,
                self.task_config.arrival_rate_per_day.get(service.value, 0.0),
            )
            effective_rate = float(base_rate) * float(getattr(params, "arrival_multiplier", 1.0))
            count = int(effective_rate)
            # Bernoulli round-up for the fractional remainder.
            if self.rng.random() < (effective_rate - count):
                count += 1
            for _ in range(count):
                case = self._new_case(service, day, params)
                new_cases.append(case)
                if case.intake_channel == IntakeChannel.DIGITAL:
                    result.digital_arrivals += 1
        result.new_arrivals = len(new_cases)
        return new_cases

    def _new_case(self, service: ServiceType, day: int, params: Any) -> ApplicationCase:
        """Create one arriving case for *service*, rolling its random traits.

        Event params can tighten the SLA window, raise the missing-docs
        probability (scaled by the channel's doc-defect rate) and raise the
        field-verification probability; task-level overrides win over the
        sector profile's defaults.
        """
        self.case_counter += 1
        profile = get_sector_profile(service)
        sla_days = int(profile.sla_days * getattr(params, "sla_window_multiplier", 1.0))
        sla_deadline_day = day + sla_days
        digital_ratio = self.task_config.digital_intake_ratio
        channel = (
            IntakeChannel.DIGITAL
            if self.rng.random() < digital_ratio
            else IntakeChannel.PAPER
        )
        base_missing = profile.missing_docs_probability
        # Overrides may be keyed by the enum or by its string value.
        override = (self.task_config.missing_docs_probability_override or {}).get(
            service,
            (self.task_config.missing_docs_probability_override or {}).get(service.value),
        )
        if override is not None:
            base_missing = override
        defect_rate = (
            profile.doc_defect_rate_digital
            if channel == IntakeChannel.DIGITAL
            else profile.doc_defect_rate_paper
        )
        eff_missing = min(
            1.0,
            base_missing + getattr(params, "doc_defect_rate_boost", 0.0) * defect_rate,
        )
        has_missing = self.rng.random() < eff_missing
        base_fv = profile.field_verification_probability
        fv_override = (self.task_config.field_verification_probability_override or {}).get(
            service,
            (self.task_config.field_verification_probability_override or {}).get(service.value),
        )
        if fv_override is not None:
            base_fv = fv_override
        eff_fv = min(1.0, base_fv + getattr(params, "field_verification_boost", 0.0))
        has_fv = self.rng.random() < eff_fv
        field_completion_day = day + profile.field_verification_days if has_fv else None
        # Function-level import kept: UrgencyProfile is not in the module-level
        # app.models import list. Executed once per case, not in a hot loop.
        from app.models import UrgencyProfile

        urgency_profile = profile.urgency_profile
        is_urgent = (
            urgency_profile == UrgencyProfile.HIGH and self.rng.random() < 0.20
        ) or (
            urgency_profile == UrgencyProfile.MODERATE and self.rng.random() < 0.08
        )
        return ApplicationCase(
            case_id=f"case-{self.case_counter:06d}",
            service_type=service,
            arrival_day=day,
            current_day=day,
            sla_deadline_day=sla_deadline_day,
            intake_channel=channel,
            internal_substate=(
                InternalSubstate.BLOCKED_MISSING_DOCS
                if has_missing
                else InternalSubstate.PRE_SCRUTINY
            ),
            public_stage=StageType.SUBMISSION,
            is_urgent=is_urgent,
            has_missing_docs=has_missing,
            field_verification_required=has_fv,
            field_verification_completion_day=field_completion_day,
        )

    def _resolve_field_verification(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        result: DayResult,
    ) -> None:
        """Release cases whose field verification has completed by *day*
        back into PRE_SCRUTINY; updates result.field_verif_completed."""
        for case in active_cases:
            if (
                case.internal_substate == InternalSubstate.FIELD_VERIFICATION_PENDING
                and case.field_verification_completion_day is not None
                and day >= case.field_verification_completion_day
            ):
                case.internal_substate = InternalSubstate.PRE_SCRUTINY
                case.field_verification_completion_day = None
                result.field_verif_completed += 1

    def _resolve_doc_requests(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        result: DayResult,
    ) -> None:
        """Release cases whose requested documents have arrived by *day*
        back into PRE_SCRUTINY; updates result.newly_unblocked_missing."""
        for case in active_cases:
            if (
                case.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
                and case.doc_resolution_day is not None
                and day >= case.doc_resolution_day
            ):
                case.internal_substate = InternalSubstate.PRE_SCRUTINY
                case.doc_resolution_day = None
                result.newly_unblocked_missing += 1

    def _sort_queue(
        self,
        cases: list[ApplicationCase],
        priority_mode: PriorityMode,
    ) -> list[ApplicationCase]:
        """Order workable cases for processing under *priority_mode*.

        Only cases that can currently advance are returned. Keys:
          URGENT_FIRST      — urgency, then SLA risk, then arrival day.
          OLDEST_FIRST      — arrival day.
          BACKLOG_CLEARANCE — SLA risk, then urgency, then arrival day.
          default           — near-breach (risk > 0.8) first, then urgency, then age.
        """
        eligible = [c for c in cases if can_advance(c)]
        if priority_mode == PriorityMode.URGENT_FIRST:
            return sorted(
                eligible,
                key=lambda c: (not c.is_urgent, -c.sla_risk, c.arrival_day),
            )
        if priority_mode == PriorityMode.OLDEST_FIRST:
            return sorted(eligible, key=lambda c: c.arrival_day)
        if priority_mode == PriorityMode.BACKLOG_CLEARANCE:
            return sorted(
                eligible,
                key=lambda c: (-c.sla_risk, not c.is_urgent, c.arrival_day),
            )
        return sorted(
            eligible,
            key=lambda c: (
                -c.sla_risk if c.sla_risk > 0.8 else 0,
                not c.is_urgent,
                c.arrival_day,
            ),
        )

    def build_queue_snapshot(
        self,
        service: ServiceType,
        active_cases: list[ApplicationCase],
        day: int,
    ) -> QueueSnapshot:
        """Summarize the pending queue for *service* into a QueueSnapshot.

        Only non-completed, non-rejected cases are counted;
        total_completed_today is always 0 here (filled in elsewhere).
        """
        cases = [
            c
            for c in active_cases
            if c.service_type == service and not c.completed and not c.rejected
        ]
        stage_counts = {s.value: 0 for s in StageType}
        for c in cases:
            stage_counts[c.public_stage.value] = stage_counts.get(c.public_stage.value, 0) + 1
        oldest_age = max((c.waiting_days for c in cases), default=0)
        avg_wait = sum(c.waiting_days for c in cases) / len(cases) if cases else 0.0
        sla_risk = sum(c.sla_risk for c in cases) / len(cases) if cases else 0.0
        return QueueSnapshot(
            service_type=service,
            public_stage_counts=stage_counts,
            total_pending=len(cases),
            total_completed_today=0,
            total_sla_breached=sum(1 for c in cases if c.sla_breached),
            urgent_pending=sum(1 for c in cases if c.is_urgent),
            blocked_missing_docs=sum(
                1
                for c in cases
                if c.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
            ),
            field_verification_pending=sum(
                1
                for c in cases
                if c.internal_substate == InternalSubstate.FIELD_VERIFICATION_PENDING
            ),
            oldest_case_age_days=oldest_age,
            avg_waiting_days=round(avg_wait, 2),
            current_sla_risk=round(min(1.0, sla_risk), 3),
        )
| # ───────────────────────────────────────────── | |
| # HIGH-LEVEL SIMULATION ORCHESTRATION | |
| # ───────────────────────────────────────────── | |
class SimulationAgentMode(str, Enum):
    """How actions are chosen during a simulation run."""

    BASELINE_POLICY = "baseline_policy"
    LLM_INFERENCE = "llm_inference"
    TRAINED_RL = "trained_rl"
class SimulationRun:
    """Typed record describing one completed simulation run.

    NOTE(review): the body contains annotations only and there is no
    @dataclass decorator, so no __init__/__repr__ is generated and these
    names are class-level annotations rather than instance attributes.
    Confirm whether @dataclass was intended.
    """

    task_id: str
    agent_mode: SimulationAgentMode
    seed: int
    total_reward: float
    score: float
    grader_name: str
    summary: dict[str, Any]
    trace: list[dict[str, Any]]
| def _dedupe(values: list[str | None]) -> list[str]: | |
| out: list[str] = [] | |
| for value in values: | |
| if value is None: | |
| continue | |
| v = str(value).strip() | |
| if v and v not in out: | |
| out.append(v) | |
| return out | |
| def _env_csv_list(name: str) -> list[str]: | |
| raw = os.getenv(name, "").strip() | |
| if not raw: | |
| return [] | |
| return [x.strip() for x in raw.split(",") if x.strip()] | |
| def _extract_json_object(text: str) -> dict[str, Any] | None: | |
| text = (text or "").strip() | |
| if not text: | |
| return None | |
| try: | |
| parsed = json.loads(text) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except json.JSONDecodeError: | |
| pass | |
| match = re.search(r"\{.*\}", text, flags=re.DOTALL) | |
| if not match: | |
| return None | |
| try: | |
| parsed = json.loads(match.group(0)) | |
| except json.JSONDecodeError: | |
| return None | |
| return parsed if isinstance(parsed, dict) else None | |
def _enum_service(value: Any) -> "ServiceType | None":
    """Coerce *value* to a ServiceType; None for blank or unrecognized input."""
    if value is None or value == "":
        return None
    if isinstance(value, ServiceType):
        return value
    try:
        coerced = ServiceType(str(value))
    except Exception:
        return None
    return coerced
def _enum_priority(value: Any) -> "PriorityMode | None":
    """Coerce *value* to a PriorityMode; None for blank or unrecognized input."""
    if value is None or value == "":
        return None
    if isinstance(value, PriorityMode):
        return value
    try:
        coerced = PriorityMode(str(value))
    except Exception:
        return None
    return coerced
def _action_model_from_kwargs(action_type: ActionType, **kwargs: Any) -> ActionModel:
    """Build an ActionModel for *action_type* from loosely-typed kwargs.

    Several candidate payload shapes are tried in order against the
    ActionModel constructor (field spellings differ across callers); the
    first candidate that validates wins. If nothing validates, degrades to
    a plain ADVANCE_TIME action.
    """
    # Accept both `service` and `service_target` spellings.
    service = _enum_service(kwargs.get("service") or kwargs.get("service_target"))
    target_service = _enum_service(kwargs.get("target_service"))
    escalation_target = _enum_service(kwargs.get("escalation_target"))
    priority_mode = _enum_priority(kwargs.get("priority_mode"))
    officer_delta = kwargs.get("officer_delta")
    case_id = kwargs.get("case_id")
    candidates: list[dict[str, Any]] = []
    if action_type == ActionType.ADVANCE_TIME:
        candidates.append({"action_type": action_type})
    elif action_type == ActionType.SET_PRIORITY_MODE:
        candidates.extend(
            [
                {"action_type": action_type, "priority_mode": priority_mode},
            ]
        )
    elif action_type == ActionType.ASSIGN_CAPACITY:
        if service is not None:
            # At least one officer is always assigned.
            delta = max(1, int(officer_delta or 1))
            candidates.extend(
                [
                    {"action_type": action_type, "service": service, "officer_delta": delta},
                    {"action_type": action_type, "service_target": service, "officer_delta": delta},
                    {
                        "action_type": action_type,
                        "capacity_assignment": {service.value: delta},
                    },
                ]
            )
    elif action_type == ActionType.REQUEST_MISSING_DOCUMENTS:
        if service is not None:
            candidates.extend(
                [
                    {"action_type": action_type, "service": service},
                    {"action_type": action_type, "service_target": service},
                ]
            )
    elif action_type == ActionType.ESCALATE_SERVICE:
        # An explicit escalation target wins over the generic service.
        svc = escalation_target or service
        candidates.extend(
            [
                {"action_type": action_type, "service": svc, "case_id": case_id},
                {"action_type": action_type, "service_target": svc, "case_id": case_id},
                {"action_type": action_type, "escalation_target": svc, "case_id": case_id},
            ]
        )
    elif action_type == ActionType.REALLOCATE_OFFICERS:
        if service is not None and target_service is not None:
            delta = max(1, int(officer_delta or 1))
            candidates.extend(
                [
                    {
                        "action_type": action_type,
                        "service": service,
                        "target_service": target_service,
                        "officer_delta": delta,
                    },
                    {
                        # Alternate shape: single signed-delta mapping.
                        "action_type": action_type,
                        "reallocation_delta": {
                            service.value: -delta,
                            target_service.value: delta,
                        },
                    },
                ]
            )
    for candidate in candidates:
        try:
            return ActionModel(**candidate)
        except Exception:
            continue  # try the next payload shape
    # Nothing validated: fall back to the always-valid no-op action.
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
def _coerce_action(payload: dict[str, Any] | None) -> ActionModel:
    """Convert an untrusted JSON payload into an ActionModel.

    Accepts both snake_case and camelCase key spellings, plus two aggregate
    shapes: `capacity_assignment` ({service: delta}) and
    `reallocation_delta` ({service: signed delta}). Any unusable payload
    degrades to ADVANCE_TIME.
    """
    if not payload:
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    raw_action_type = payload.get("action_type") or payload.get("actionType")
    try:
        action_type = ActionType(str(raw_action_type))
    except Exception:
        # Unknown or missing action type: fall back to the no-op action.
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    service = payload.get("service") or payload.get("service_target") or payload.get("serviceTarget")
    target_service = payload.get("target_service") or payload.get("targetService")
    escalation_target = payload.get("escalation_target") or payload.get("escalationTarget")
    priority_mode = payload.get("priority_mode") or payload.get("priorityMode")
    officer_delta = payload.get("officer_delta") or payload.get("officerDelta")
    case_id = payload.get("case_id") or payload.get("caseId")
    if action_type == ActionType.ASSIGN_CAPACITY and not service:
        # Aggregate form: take the first (service, delta) pair.
        assignment = payload.get("capacity_assignment") or {}
        if isinstance(assignment, dict) and assignment:
            service, officer_delta = next(iter(assignment.items()))
    if action_type == ActionType.REALLOCATE_OFFICERS and (not service or not target_service):
        # Aggregate form: negative entries donate officers, positive receive.
        delta_map = payload.get("reallocation_delta") or {}
        if isinstance(delta_map, dict) and len(delta_map) >= 2:
            negatives = [k for k, v in delta_map.items() if int(v) < 0]
            positives = [k for k, v in delta_map.items() if int(v) > 0]
            if negatives and positives:
                service = negatives[0]
                target_service = positives[0]
                officer_delta = abs(int(delta_map[service]))
    return _action_model_from_kwargs(
        action_type,
        service=service,
        target_service=target_service,
        escalation_target=escalation_target,
        priority_mode=priority_mode,
        officer_delta=officer_delta,
        case_id=case_id,
    )
| def _recommended_min_steps(task_id: str) -> int: | |
| if task_id == "cross_department_hard": | |
| return 70 | |
| if task_id == "mixed_urgency_medium": | |
| return 60 | |
| return 40 | |
| def _queue_snapshot_iter(obs: ObservationModel) -> list[Any]: | |
| raw = getattr(obs, "queue_snapshots", []) | |
| if isinstance(raw, dict): | |
| return list(raw.values()) | |
| if isinstance(raw, list): | |
| return list(raw) | |
| try: | |
| return list(raw) | |
| except Exception: | |
| return [] | |
def _queue_service(q: Any) -> "ServiceType | None":
    """Resolve a snapshot's service, accepting `service` or legacy `service_type`."""
    raw = getattr(q, "service", None) or getattr(q, "service_type", None)
    return _enum_service(raw)
| def _queue_active_cases(q: Any) -> int: | |
| return int(getattr(q, "active_cases", getattr(q, "total_pending", 0)) or 0) | |
| def _queue_missing_docs(q: Any) -> int: | |
| return int(getattr(q, "missing_docs_cases", getattr(q, "blocked_missing_docs", 0)) or 0) | |
| def _queue_urgent_cases(q: Any) -> int: | |
| return int(getattr(q, "urgent_cases", getattr(q, "urgent_pending", 0)) or 0) | |
| def _queue_breached_cases(q: Any) -> int: | |
| return int(getattr(q, "breached_cases", getattr(q, "total_sla_breached", 0)) or 0) | |
| def _queue_avg_age(q: Any) -> float: | |
| if hasattr(q, "avg_age_days"): | |
| return float(getattr(q, "avg_age_days") or 0.0) | |
| if hasattr(q, "oldest_case_age_days"): | |
| return float(getattr(q, "oldest_case_age_days") or 0.0) | |
| return float(getattr(q, "avg_waiting_days", 0.0) or 0.0) | |
def _queue_rows(obs: "ObservationModel") -> list[dict[str, Any]]:
    """Flatten queue snapshots into plain dict rows.

    Snapshots whose service cannot be resolved are skipped.
    """
    rows: list[dict[str, Any]] = []
    for snap in _queue_snapshot_iter(obs):
        svc = _queue_service(snap)
        if svc is None:
            continue
        rows.append(
            {
                "service": svc.value,
                "active_cases": _queue_active_cases(snap),
                "missing_docs_cases": _queue_missing_docs(snap),
                "urgent_cases": _queue_urgent_cases(snap),
                "breached_cases": _queue_breached_cases(snap),
                "avg_age_days": _queue_avg_age(snap),
            }
        )
    return rows
| def _pool_allocations(obs: ObservationModel) -> dict[Any, Any]: | |
| pool = getattr(obs, "officer_pool", None) | |
| if pool is None: | |
| return {} | |
| return getattr(pool, "allocations", getattr(pool, "allocated", {})) or {} | |
| def _reserve_officers(obs: ObservationModel) -> int: | |
| pool = getattr(obs, "officer_pool", None) | |
| if pool is None: | |
| return 0 | |
| for name in ("reserve_officers", "idle_officers", "available_officers"): | |
| if hasattr(pool, name): | |
| try: | |
| return int(getattr(pool, name) or 0) | |
| except Exception: | |
| pass | |
| return 0 | |
def _alloc_for(obs: "ObservationModel", service: "ServiceType") -> int:
    """Officers currently allocated to *service*; the mapping may be keyed
    by the enum itself or by its string value."""
    allocs = _pool_allocations(obs)
    value = allocs.get(service)
    if value is None:
        value = allocs.get(service.value, 0)
    return int(value or 0)
def _top_backlog_service(
    obs: "ObservationModel",
    *,
    exclude: "ServiceType | None" = None,
) -> "ServiceType | None":
    """Service whose queue looks most loaded.

    Load is (active + 2*breached + urgent), ties broken by age; *exclude*
    removes one service from consideration. None when nothing qualifies.
    """
    def load(snap: Any) -> tuple:
        return (
            _queue_active_cases(snap) + (2 * _queue_breached_cases(snap)) + _queue_urgent_cases(snap),
            _queue_avg_age(snap),
        )

    contenders = []
    for snap in _queue_snapshot_iter(obs):
        svc = _queue_service(snap)
        if svc is not None and svc != exclude:
            contenders.append(snap)
    if not contenders:
        return None
    # max() returns the first maximal element, matching the original
    # stable reverse-sort-then-take-first behavior.
    return _queue_service(max(contenders, key=load))
def _service_with_missing_docs(obs: "ObservationModel") -> "ServiceType | None":
    """Service with the largest missing-docs backlog (ties broken by total
    pending cases); None when no queue has missing documents."""
    blocked = [q for q in _queue_snapshot_iter(obs) if _queue_missing_docs(q) > 0]
    if not blocked:
        return None
    worst = max(blocked, key=lambda q: (_queue_missing_docs(q), _queue_active_cases(q)))
    return _queue_service(worst)
def _service_with_officers(obs: "ObservationModel") -> "ServiceType | None":
    """The service holding the most officers, or None when nothing is staffed."""
    staffed = [
        svc
        for svc in (_queue_service(q) for q in _queue_snapshot_iter(obs))
        if svc is not None
    ]
    if not staffed:
        return None
    best = max(staffed, key=lambda svc: _alloc_for(obs, svc))
    return best if _alloc_for(obs, best) > 0 else None
def _compute_action_mask(obs: "ObservationModel") -> "dict[ActionType, bool]":
    """Which action types are currently sensible for *obs*.

    SET_PRIORITY_MODE and ADVANCE_TIME are always allowed; the rest depend
    on reserve officers, backlogs, missing documents, escalation budget and
    having at least one staffed service among >= 2 queues.
    """
    snapshots = _queue_snapshot_iter(obs)
    any_missing = any(_queue_missing_docs(q) > 0 for q in snapshots)
    any_backlog = any(_queue_active_cases(q) > 0 for q in snapshots)
    budget_left = int(getattr(obs, "escalation_budget_remaining", 0) or 0) > 0
    reserve_left = _reserve_officers(obs) > 0
    staffed = [
        q
        for q in snapshots
        if (svc := _queue_service(q)) is not None and _alloc_for(obs, svc) > 0
    ]
    reallocatable = bool(staffed) and len(snapshots) >= 2
    return {
        ActionType.SET_PRIORITY_MODE: True,
        ActionType.ADVANCE_TIME: True,
        ActionType.ASSIGN_CAPACITY: reserve_left and any_backlog,
        ActionType.REQUEST_MISSING_DOCUMENTS: any_missing,
        ActionType.ESCALATE_SERVICE: budget_left and any_backlog,
        ActionType.REALLOCATE_OFFICERS: reallocatable,
    }
def _masked_action_type_hints(obs: "ObservationModel") -> tuple[list[str], list[str]]:
    """Split the action mask into (allowed, blocked) lists of type values."""
    mask = _compute_action_mask(obs)
    allowed = [action.value for action, permitted in mask.items() if permitted]
    blocked = [action.value for action, permitted in mask.items() if not permitted]
    return allowed, blocked
def _best_high_impact_action(obs: ObservationModel) -> tuple[ActionModel, str]:
    """Pick a sensible action for the current observation, with a reason.

    Preference order:
      1. assign a reserve officer to the busiest queue,
      2. request documents for the largest missing-docs backlog,
      3. spend escalation budget on the hottest (breached/loaded) queue,
      4. move one officer from the best-staffed service to the busiest other,
      5. fall back to advancing time.
    """
    top_backlog = _top_backlog_service(obs)
    top_missing = _service_with_missing_docs(obs)
    if _reserve_officers(obs) > 0 and top_backlog is not None:
        return (
            _action_model_from_kwargs(
                ActionType.ASSIGN_CAPACITY,
                service=top_backlog,
                officer_delta=1,
            ),
            "high-impact: assign reserve capacity to top backlog service",
        )
    if top_missing is not None:
        return (
            _action_model_from_kwargs(
                ActionType.REQUEST_MISSING_DOCUMENTS,
                service=top_missing,
            ),
            "high-impact: clear missing-document bottleneck",
        )
    if int(getattr(obs, "escalation_budget_remaining", 0) or 0) > 0:
        # Rank queues hottest-first: breaches, then load, then urgency.
        hot = sorted(
            _queue_snapshot_iter(obs),
            key=lambda q: (_queue_breached_cases(q), _queue_active_cases(q), _queue_urgent_cases(q)),
            reverse=True,
        )
        if hot and (_queue_breached_cases(hot[0]) > 0 or _queue_active_cases(hot[0]) > 0):
            service = _queue_service(hot[0])
            if service is not None:
                return (
                    _action_model_from_kwargs(
                        ActionType.ESCALATE_SERVICE,
                        service=service,
                    ),
                    "high-impact: escalate highest SLA-risk service",
                )
    source = _service_with_officers(obs)
    if source is not None and _alloc_for(obs, source) > 0:
        target = _top_backlog_service(obs, exclude=source)
        if target is not None and target != source:
            return (
                _action_model_from_kwargs(
                    ActionType.REALLOCATE_OFFICERS,
                    service=source,
                    target_service=target,
                    officer_delta=1,
                ),
                "high-impact: reallocate one officer toward highest backlog",
            )
    return ActionModel(action_type=ActionType.ADVANCE_TIME), "fallback: no high-impact action available"
def _repair_action_for_observation(
    action: ActionModel,
    obs: ObservationModel,
) -> tuple[ActionModel, str | None]:
    """Validate *action* against *obs*, repairing or replacing it if needed.

    Returns (action, note) where note is None when the input was already
    valid. Masked or unfulfillable actions are swapped for the best
    high-impact alternative; fixable payloads (missing service, delta out
    of range, duplicate source/target) are repaired in place.
    """
    mask = _compute_action_mask(obs)
    at = action.action_type
    if not bool(mask.get(at, True)):
        # Action type is masked out for this observation: replace entirely.
        fallback, why = _best_high_impact_action(obs)
        return fallback, f"masked {at.value}; {why}"
    if at == ActionType.ADVANCE_TIME:
        return action, None
    if at == ActionType.SET_PRIORITY_MODE:
        if getattr(action, "priority_mode", None) is None:
            return (
                _action_model_from_kwargs(
                    ActionType.SET_PRIORITY_MODE,
                    priority_mode=PriorityMode.BACKLOG_CLEARANCE,
                ),
                "missing priority_mode, defaulted to backlog_clearance",
            )
        return action, None
    if at == ActionType.ASSIGN_CAPACITY:
        reserve = _reserve_officers(obs)
        if reserve <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"reserve officers exhausted; {why}"
        service = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _top_backlog_service(obs)
        if service is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no service available for assign_capacity; {why}"
        # Clamp the delta to [1, reserve].
        delta = max(1, int(getattr(action, "officer_delta", 1) or 1))
        delta = min(delta, reserve)
        repaired = _action_model_from_kwargs(
            ActionType.ASSIGN_CAPACITY,
            service=service,
            officer_delta=delta,
        )
        return repaired, "repaired assign_capacity payload"
    if at == ActionType.REQUEST_MISSING_DOCUMENTS:
        service = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _service_with_missing_docs(obs)
        if service is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no missing-doc queue available; {why}"
        repaired = _action_model_from_kwargs(
            ActionType.REQUEST_MISSING_DOCUMENTS,
            service=service,
        )
        return repaired, "repaired request_missing_documents payload"
    if at == ActionType.ESCALATE_SERVICE:
        if int(getattr(obs, "escalation_budget_remaining", 0) or 0) <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"escalation budget exhausted; {why}"
        # Target resolution order: service > service_target > escalation_target
        # > busiest queue.
        service = (
            _enum_service(getattr(action, "service", None))
            or _enum_service(getattr(action, "service_target", None))
            or _enum_service(getattr(action, "escalation_target", None))
            or _top_backlog_service(obs)
        )
        case_id = getattr(action, "case_id", None)
        if service is None and case_id is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no escalation target available; {why}"
        repaired = _action_model_from_kwargs(
            ActionType.ESCALATE_SERVICE,
            service=service,
            case_id=case_id,
        )
        return repaired, "repaired escalate_service payload"
    if at == ActionType.REALLOCATE_OFFICERS:
        source = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _service_with_officers(obs)
        if source is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no staffed source service; {why}"
        source_alloc = _alloc_for(obs, source)
        if source_alloc <= 0:
            # Requested source has no officers: retry with the best-staffed one.
            source = _service_with_officers(obs)
            source_alloc = _alloc_for(obs, source) if source is not None else 0
        if source is None or source_alloc <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"insufficient source officers; {why}"
        target = _enum_service(getattr(action, "target_service", None))
        if target is None or target == source:
            target = _top_backlog_service(obs, exclude=source)
        if target is None or target == source:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"missing distinct target_service; {why}"
        # Clamp the delta to [1, source_alloc].
        delta = max(1, int(getattr(action, "officer_delta", 1) or 1))
        delta = min(delta, source_alloc)
        repaired = _action_model_from_kwargs(
            ActionType.REALLOCATE_OFFICERS,
            service=source,
            target_service=target,
            officer_delta=delta,
        )
        return repaired, "repaired reallocate_officers payload"
    return action, None
def _model_label_for_mode(agent_mode: "SimulationAgentMode") -> str:
    """Label for the decision source; LLM mode reads MODEL_NAME from the env."""
    fixed_labels = {
        SimulationAgentMode.BASELINE_POLICY: "baseline_policy",
        SimulationAgentMode.TRAINED_RL: "trained_rl",
    }
    label = fixed_labels.get(agent_mode)
    if label is not None:
        return label
    return os.getenv("MODEL_NAME", "llm_inference")
| def _log_step_line(step_row: dict[str, Any]) -> str: | |
| done = "true" if bool(step_row.get("done")) else "false" | |
| error = step_row.get("last_action_error") or "null" | |
| action = json.dumps(step_row.get("action_payload", {}), separators=(",", ":")) | |
| source = step_row.get("decision_source") or "unknown" | |
| model = step_row.get("model_used") or "null" | |
| repair = step_row.get("repair_note") or "null" | |
| switch_note = step_row.get("switch_note") or "null" | |
| return ( | |
| f"[STEP] step={step_row.get('step', 0)} action={action} " | |
| f"reward={float(step_row.get('reward', 0.0)):.2f} done={done} " | |
| f"error={error} source={source} model={model} repair={repair} switch={switch_note}" | |
| ) | |
| def _resolve_model_path_or_raise(model_path: str) -> str: | |
| p = Path(model_path).expanduser() | |
| if not p.is_absolute(): | |
| p = (Path.cwd() / p).resolve() | |
| if p.is_dir(): | |
| candidates = [ | |
| p / "best_model.zip", | |
| p / "model.zip", | |
| p / "checkpoint.zip", | |
| ] | |
| zip_files = sorted(p.glob("*.zip")) | |
| candidates.extend(zip_files) | |
| for candidate in candidates: | |
| if candidate.exists(): | |
| return str(candidate) | |
| if p.exists(): | |
| return str(p) | |
| raise FileNotFoundError(f"Model path not found: {model_path}") | |
def _load_model_cached_or_raise(model_abs: str, model_type: Literal["maskable", "recurrent"]) -> Any:
    """Load an SB3 model from ``model_abs``, memoized in ``_MODEL_CACHE``.

    ``"recurrent"`` loads RecurrentPPO; anything else tries MaskablePPO and
    falls back to vanilla PPO when sb3_contrib is unavailable or its load
    fails. The cache key is (absolute path, model type).
    """
    cache_key = (model_abs, model_type)
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]
    if model_type == "recurrent":
        from sb3_contrib import RecurrentPPO
        loaded = RecurrentPPO.load(model_abs)
    else:
        try:
            from sb3_contrib import MaskablePPO
            loaded = MaskablePPO.load(model_abs)
        except Exception:
            from stable_baselines3 import PPO
            loaded = PPO.load(model_abs)
    _MODEL_CACHE[cache_key] = loaded
    return loaded
| def _safe_invalid_action_count(final_state: Any) -> int: | |
| if hasattr(final_state, "total_invalid_actions"): | |
| return int(getattr(final_state, "total_invalid_actions") or 0) | |
| metrics = getattr(final_state, "metrics", None) | |
| if metrics is not None and hasattr(metrics, "total_invalid_actions"): | |
| return int(getattr(metrics, "total_invalid_actions") or 0) | |
| return 0 | |
class LiveSimulationSession:
    """Stateful driver for a single interactive simulation episode.

    Three agent modes are supported (``SimulationAgentMode``):
      * baseline policy -- a named local policy from ``app.baselines``;
      * trained RL      -- an SB3 checkpoint stepped through a gym env;
      * LLM inference   -- multi-provider LLM routing with runtime action
        masking/repair, per-model failure statistics, and a heuristic
        fallback policy.

    Callers drive the episode with ``step_once()`` until it reports done,
    then read ``score``, ``summary``, and ``trace``.
    """

    def __init__(
        self,
        *,
        task_id: str,
        agent_mode: SimulationAgentMode,
        max_steps: int,
        seed: int | None,
        policy_name: str | None = None,
        model_path: str | None = None,
        model_type: Literal["maskable", "recurrent"] = "maskable",
    ) -> None:
        self.task_id = task_id
        self.agent_mode = agent_mode
        # LLM mode enforces a per-task floor on max_steps so short runs
        # still reach a meaningful horizon; other modes take the caller's
        # value unchanged.
        recommended = _recommended_min_steps(task_id)
        self.max_steps = max(int(max_steps), int(recommended)) if agent_mode == SimulationAgentMode.LLM_INFERENCE else int(max_steps)
        self.seed = int(seed if seed is not None else random.randint(1, 999999))
        self.policy_name = policy_name or "backlog_clearance"
        self.model_path = model_path
        self.model_type = model_type
        self.trace: list[dict[str, Any]] = []
        self.total_reward = 0.0
        self.step_idx = 0
        self.done = False
        self.summary: dict[str, Any] | None = None
        self.score: float | None = None
        self.grader_name: str | None = None
        # Core-env fields (baseline / LLM modes).
        self.env: Any = None
        self.obs: ObservationModel | Any = None
        self.policy: Any = None
        # Trained-RL fields (gym env, loaded model, recurrent-policy state).
        self.rl_env: Any = None
        self.rl_model: Any = None
        self.rl_lstm_state: Any = None
        self.rl_episode_start: Any = None
        # LLM routing state: configured runtimes, a human-readable route
        # description, and per-(provider, model) call/failure statistics.
        self.llm_runtimes: list[dict[str, Any]] = []
        self.llm_route: list[str] = []
        self.llm_model_stats: dict[tuple[str, str], dict[str, Any]] = {}
        # Auto-recovery bookkeeping: consecutive bad steps trigger a
        # temporary switch to the heuristic policy (see _step_core).
        self.consecutive_failure_steps = 0
        self.recovery_steps_remaining = 0
        self.auto_switch_count = 0
        self.last_switch_reason: str | None = None
        if self.agent_mode == SimulationAgentMode.TRAINED_RL:
            self._init_trained()
        else:
            self._init_core()

    def start_line(self) -> dict[str, Any]:
        """Return the initial [START] log line plus the first observation."""
        return {
            "log": (
                f"[START] task={self.task_id} env=gov-workflow-openenv "
                f"model={_model_label_for_mode(self.agent_mode)}"
            ),
            "observation": self.obs
        }

    def _init_core(self) -> None:
        """Create the core env and pick the decision function (baseline or LLM)."""
        # Imported lazily so that trained-RL sessions never pull these in.
        from app.baselines import POLICIES, backlog_clearance_policy
        from app.env import GovWorkflowEnv
        self.env = GovWorkflowEnv(task_id=self.task_id)
        self.obs, _ = self.env.reset(seed=self.seed)
        if self.agent_mode == SimulationAgentMode.BASELINE_POLICY:
            # Unknown policy names silently fall back to backlog clearance.
            self.policy = POLICIES.get(self.policy_name, backlog_clearance_policy)
        else:
            self.policy = self._llm_action_with_meta
            self._init_llm_runtimes()

    def _init_llm_runtimes(self) -> None:
        """Build the ordered LLM runtime list from environment variables.

        Two providers are probed: an OpenAI-compatible endpoint and NVIDIA's
        integrate API. Each runtime bundles one base URL, every usable API
        key (as a pre-built client + label), and a deduped model list. Also
        initializes per-model stats and the human-readable route summary.
        """
        openai_base = os.getenv("API_BASE_URL") or os.getenv("OPENAI_API_BASE_URL") or "https://api.openai.com/v1"
        nvidia_base = os.getenv("NVIDIA_API_BASE_URL", "https://integrate.api.nvidia.com/v1")
        openai_keys = _dedupe(
            [
                os.getenv("HF_TOKEN"),
                os.getenv("OPENAI_API_KEY"),
                os.getenv("API_KEY"),
            ]
        )
        nvidia_keys = _dedupe(
            [
                os.getenv("NVIDIA_API_KEY"),
                os.getenv("NVIDIA_API_KEY_2"),
            ]
        )
        openai_models = _dedupe(
            [
                os.getenv("MODEL_NAME", "meta/llama-3.3-70b-instruct"),
                *_env_csv_list("MODEL_FALLBACKS"),
            ]
        )
        nvidia_models = _dedupe(
            [
                os.getenv("NVIDIA_MODEL"),
                *_env_csv_list("NVIDIA_MODEL_FALLBACKS"),
                *LEGACY_NVIDIA_MODEL_POOL,
            ]
        )
        runtimes: list[dict[str, Any]] = []
        if openai_keys and openai_models:
            clients: list[tuple[OpenAI, str]] = []
            for idx, key in enumerate(openai_keys, start=1):
                try:
                    clients.append(
                        (
                            # Short timeout, no client retries: routing logic
                            # handles failover across keys/models itself.
                            OpenAI(base_url=openai_base, api_key=key, timeout=8.0, max_retries=0),
                            f"openai_key_{idx}",
                        )
                    )
                except Exception:
                    # A bad key/base URL just drops this client from the pool.
                    continue
            if clients:
                runtimes.append(
                    {
                        "provider": "openai-compatible",
                        "base_url": openai_base,
                        "clients": clients,
                        "models": openai_models,
                    }
                )
        if nvidia_keys and nvidia_models:
            clients = []
            for idx, key in enumerate(nvidia_keys, start=1):
                try:
                    clients.append(
                        (
                            OpenAI(base_url=nvidia_base, api_key=key, timeout=8.0, max_retries=0),
                            f"nvidia_key_{idx}",
                        )
                    )
                except Exception:
                    continue
            if clients:
                runtimes.append(
                    {
                        "provider": "nvidia",
                        "base_url": nvidia_base,
                        "clients": clients,
                        "models": nvidia_models,
                    }
                )
        self.llm_runtimes = runtimes
        self.llm_model_stats = {}
        for runtime in runtimes:
            provider = str(runtime.get("provider"))
            for model in runtime.get("models", []):
                # Stats feed _rank_runtime_models and the final summary.
                self.llm_model_stats[(provider, str(model))] = {
                    "calls": 0,
                    "invalid": 0,
                    "repaired": 0,
                    "failures": 0,
                    "cooldown_until_step": 0,
                }
        openai_runtime = next((rt for rt in runtimes if rt.get("provider") == "openai-compatible"), None)
        nvidia_runtime = next((rt for rt in runtimes if rt.get("provider") == "nvidia"), None)
        openai_route = (
            f"openai-compatible ({len(openai_runtime['clients'])} keys, {len(openai_runtime['models'])} models)"
            if openai_runtime is not None
            else "openai-compatible (unavailable: missing API key/model)"
        )
        nvidia_route = (
            f"nvidia ({len(nvidia_runtime['clients'])} keys, {len(nvidia_runtime['models'])} models)"
            if nvidia_runtime is not None
            else "nvidia (unavailable: missing API key/model)"
        )
        self.llm_route = [
            openai_route,
            nvidia_route,
            "adaptive ranking: prefer models with lower invalid/repaired rates",
            "heuristic fallback (backlog_clearance_policy)",
        ]

    def _rank_runtime_models(self, provider: str, models: list[str]) -> list[str]:
        """Order models by observed quality: lowest badness score first.

        The primary key is a weighted mix of invalid/repaired/failure rates
        plus a flat penalty while a cooldown is active; ties prefer the model
        with more recorded calls (negated for ascending sort).
        """
        def _score(model_name: str) -> tuple[float, int]:
            stat = self.llm_model_stats.get((provider, model_name), {})
            calls = max(1, int(stat.get("calls", 0)))
            invalid_rate = float(stat.get("invalid", 0)) / calls
            repaired_rate = float(stat.get("repaired", 0)) / calls
            fail_rate = float(stat.get("failures", 0)) / calls
            cooldown = int(stat.get("cooldown_until_step", 0))
            cooldown_penalty = 1.0 if self.step_idx < cooldown else 0.0
            return (
                invalid_rate * 2.0 + repaired_rate * 1.25 + fail_rate * 1.5 + cooldown_penalty,
                -calls,
            )
        return sorted([str(m) for m in models], key=_score)

    def _llm_action_with_meta(self, obs: ObservationModel) -> tuple[ActionModel, dict[str, Any]]:
        """Ask the configured LLMs for the next action.

        Returns (action, meta) where meta records the decision source,
        provider/model used, attempt count, and any error/repair notes.
        While recovery steps remain, skips the LLM entirely and uses the
        heuristic policy. Otherwise iterates runtimes -> clients -> ranked
        models until one call yields parseable JSON; on total failure falls
        back to the heuristic policy.
        """
        if self.recovery_steps_remaining > 0:
            self.recovery_steps_remaining -= 1
            action, why = _best_high_impact_action(obs)
            return action, {
                "decision_source": "auto_recovery_policy",
                "provider": "heuristic",
                "model_used": "backlog_clearance_policy",
                "llm_attempts": 0,
                "llm_error": None,
                "llm_key_label": None,
                "repair_note": why,
            }
        attempts = 0
        last_error = ""
        allowed_actions, blocked_actions = _masked_action_type_hints(obs)
        # Schema hints are embedded in the user prompt to steer the model
        # toward valid payloads on the first try.
        schema_hint = {
            "required_fields": {
                "set_priority_mode": ["action_type", "priority_mode"],
                "assign_capacity": ["action_type", "service", "officer_delta"],
                "request_missing_documents": ["action_type", "service"],
                "escalate_service": ["action_type", "service"],
                "advance_time": ["action_type"],
                "reallocate_officers": ["action_type", "service", "target_service", "officer_delta"],
            },
            "allowed_priority_mode": [m.value for m in PriorityMode],
            "allowed_services": [s.value for s in ServiceType],
        }
        system_prompt = (
            "You are controlling a government workflow simulator. "
            "Return exactly one JSON object only. No markdown. No explanation. "
            "Allowed action_type: set_priority_mode, assign_capacity, request_missing_documents, "
            "escalate_service, advance_time, reallocate_officers. "
            "Rules: "
            "1) reallocate_officers requires service + target_service + officer_delta>0 and source!=target. "
            "2) assign_capacity requires service + officer_delta>0. "
            "3) request_missing_documents requires service with missing_docs_cases>0. "
            "4) set_priority_mode requires priority_mode in [urgent_first, oldest_first, balanced, backlog_clearance]. "
            "5) Always prefer high-impact actions that reduce backlog/SLA risk over no-op loops. "
            "Use lowercase enum values."
        )
        user_prompt = (
            "Observation:\n"
            f"{obs.model_dump_json() if hasattr(obs, 'model_dump_json') else json.dumps(getattr(obs, 'dict', lambda: {})())}\n"
            f"Allowed action types now: {allowed_actions}\n"
            f"Blocked action types now: {blocked_actions}\n"
            f"Action schema hints: {json.dumps(schema_hint, separators=(',', ':'))}\n"
            f"Last action validity: {getattr(obs, 'last_action_valid', True)}\n"
            f"Last action message: {getattr(obs, 'last_action_message', '')}\n"
            "Return action JSON."
        )
        for runtime in self.llm_runtimes:
            provider = str(runtime["provider"])
            ranked_models = self._rank_runtime_models(provider, list(runtime["models"]))
            for client, key_label in runtime["clients"]:
                for model in ranked_models:
                    attempts += 1
                    stat_key = (provider, model)
                    try:
                        out = client.chat.completions.create(
                            model=model,
                            messages=[
                                {"role": "system", "content": system_prompt},
                                {"role": "user", "content": user_prompt},
                            ],
                            temperature=0.0,
                            max_tokens=200,
                            stream=False,
                        )
                        content = (out.choices[0].message.content or "").strip()
                        action = _coerce_action(_extract_json_object(content))
                        if stat_key in self.llm_model_stats:
                            self.llm_model_stats[stat_key]["calls"] += 1
                        return action, {
                            "decision_source": "llm",
                            "provider": provider,
                            "model_used": model,
                            "llm_attempts": attempts,
                            "llm_error": None,
                            "llm_key_label": key_label,
                        }
                    except Exception as exc:
                        # API error or unparseable content: record the failure
                        # and keep trying the next model/client/runtime.
                        last_error = str(exc)
                        stat = self.llm_model_stats.get(stat_key)
                        if stat is not None:
                            stat["calls"] += 1
                            stat["failures"] += 1
                            # Two failures put the model on a short cooldown so
                            # ranking deprioritizes it for the next 5 steps.
                            if stat["failures"] >= 2:
                                stat["cooldown_until_step"] = self.step_idx + 5
                        continue
        # Every runtime/model failed (or none configured): heuristic fallback.
        action, why = _best_high_impact_action(obs)
        if not self.llm_runtimes:
            last_error = "No LLM credentials configured."
        return action, {
            "decision_source": "heuristic_fallback",
            "provider": "heuristic",
            "model_used": "backlog_clearance_policy",
            "llm_attempts": attempts,
            "llm_error": last_error or None,
            "llm_key_label": None,
            "repair_note": why,
        }

    def _init_trained(self) -> None:
        """Load the RL checkpoint and create the masked gym environment."""
        import numpy as np
        from rl.gov_workflow_env import GovWorkflowGymEnv
        if not self.model_path:
            raise ValueError("model_path is required for trained_rl simulation.")
        model_abs = _resolve_model_path_or_raise(self.model_path)
        self.rl_model = _load_model_cached_or_raise(model_abs, self.model_type)
        self.rl_env = GovWorkflowGymEnv(
            task_id=self.task_id,
            seed=self.seed,
            hard_action_mask=True,
        )
        self.obs, _ = self.rl_env.reset(seed=self.seed)
        self.rl_lstm_state = None
        # episode_start=True resets the recurrent policy's hidden state.
        self.rl_episode_start = np.array([True], dtype=bool)

    def step_once(self) -> tuple[dict[str, Any], str, bool]:
        """Advance the episode one step.

        Returns (trace row, formatted log line, finished flag). Finalizes
        the episode when the env reports done or max_steps is reached.

        Raises:
            RuntimeError: if called after the episode has finished.
        """
        if self.done:
            raise RuntimeError("Simulation already finished.")
        self.step_idx += 1
        row = self._step_trained() if self.agent_mode == SimulationAgentMode.TRAINED_RL else self._step_core()
        self.trace.append(row)
        self.total_reward += float(row["reward"])
        step_log = _log_step_line(row)
        if row["done"] or self.step_idx >= self.max_steps:
            self._finalize()
            # Force done on the last row even when truncated by max_steps.
            row["done"] = True
            return row, step_log, True
        return row, step_log, False

    def end_line(self) -> str:
        """Return the [END] log line (success threshold: score >= 0.5)."""
        if self.score is None:
            return "[END] success=false steps=0 score=0.00 rewards="
        rewards = ",".join(f"{float(x.get('reward', 0.0)):.2f}" for x in self.trace)
        success = "true" if self.score >= 0.5 else "false"
        return f"[END] success={success} steps={len(self.trace)} score={self.score:.2f} rewards={rewards}"

    def step_line(self, action: dict | ActionModel) -> dict[str, Any]:
        """Test wrapper for executing an action and returning observation + reward."""
        if isinstance(action, dict):
            action = _coerce_action(action)
        self.obs, reward, terminated, truncated, info = self.env.step(action)
        return {"observation": self.obs, "reward": reward}

    def snapshot(self) -> dict[str, Any]:
        """Return a JSON-friendly summary of the session's current state."""
        return {
            "task_id": self.task_id,
            "agent_mode": self.agent_mode.value,
            "seed": self.seed,
            "max_steps": self.max_steps,
            "step_idx": self.step_idx,
            "done": self.done,
            "total_reward": float(self.total_reward),
            "score": self.score,
            "grader_name": self.grader_name,
            "summary": self.summary,
            "trace_len": len(self.trace),
            "llm_route": list(self.llm_route),
        }

    def close(self) -> None:
        """Best-effort close of whichever environments were created."""
        try:
            if self.env is not None and hasattr(self.env, "close"):
                self.env.close()
        except Exception:
            # Cleanup is deliberately best-effort; never propagate here.
            pass
        try:
            if self.rl_env is not None and hasattr(self.rl_env, "close"):
                self.rl_env.close()
        except Exception:
            pass

    def _step_core(self) -> dict[str, Any]:
        """Execute one step in baseline or LLM mode and build its trace row.

        Pipeline for LLM mode: decide -> coerce to ActionModel -> apply the
        runtime action mask -> repair the payload -> step the env -> update
        per-model stats and the auto-switch failure counter.
        """
        if self.env is None:
            raise RuntimeError("Core simulation env not initialized.")
        if self.agent_mode == SimulationAgentMode.BASELINE_POLICY:
            action = self.policy(self.obs)
            meta = {
                "decision_source": "baseline_policy",
                "provider": "local_policy",
                "model_used": self.policy_name,
                "llm_attempts": 0,
                "llm_error": None,
                "llm_key_label": None,
            }
        else:
            raw_decision = self.policy(self.obs)
            # _llm_action_with_meta returns (action, meta); tolerate plain
            # actions too in case the policy hook is swapped out.
            if isinstance(raw_decision, tuple) and len(raw_decision) == 2:
                action, meta = raw_decision
            else:
                action, meta = raw_decision, {}
        if not isinstance(meta, dict):
            meta = {}
        if not isinstance(action, ActionModel):
            if isinstance(action, dict):
                action = _coerce_action(action)
            else:
                # Unusable output: degrade to a safe no-op.
                action = ActionModel(action_type=ActionType.ADVANCE_TIME)
                meta["repair_note"] = "non-action output from llm policy, coerced to advance_time"
        allowed_mask = _compute_action_mask(self.obs)
        if not bool(allowed_mask.get(action.action_type, True)):
            # Masked action type: replace with the best heuristic action.
            masked_fallback, why = _best_high_impact_action(self.obs)
            action = masked_fallback
            if meta.get("decision_source") == "llm":
                meta["decision_source"] = "llm_repaired"
            meta["repair_note"] = f"action masked at runtime; {why}"
        repaired_action, repair_note = _repair_action_for_observation(action, self.obs)
        if repair_note:
            action = repaired_action
            if meta.get("decision_source") == "llm":
                meta["decision_source"] = "llm_repaired"
            meta["repair_note"] = repair_note
        self.obs, reward, terminated, truncated, info = self.env.step(action)
        done = bool(terminated or truncated)
        # NOTE(review): the trained path reads info with dict-style .get();
        # if this env also returns a plain dict, getattr() here always yields
        # the default (None/False) -- confirm GovWorkflowEnv's info type.
        last_action_error = getattr(info, "last_action_error", None)
        if last_action_error is None:
            last_action_error = getattr(info, "action_explanation", None)
        row = {
            "step": self.step_idx,
            "day": self.obs.day,
            "action_type": action.action_type.value,
            "action_payload": action.model_dump(exclude_none=True, mode="json"),
            "reward": float(reward),
            "done": done,
            "backlog": getattr(self.obs, "total_backlog", 0),
            "completed": getattr(self.obs, "total_completed", 0),
            "sla_breaches": getattr(self.obs, "total_sla_breaches", 0),
            "fairness_gap": float(
                getattr(self.obs, "fairness_gap", getattr(self.obs, "fairness_index", 0.0)) or 0.0
            ),
            "escalation_budget_remaining": getattr(self.obs, "escalation_budget_remaining", 0),
            "invalid_action": bool(getattr(info, "invalid_action", False)),
            "last_action_error": last_action_error,
            "queue_rows": _queue_rows(self.obs),
        }
        row.update(meta)
        if self.agent_mode == SimulationAgentMode.LLM_INFERENCE:
            # Post-step accounting: feed the outcome back into per-model
            # stats and the consecutive-failure counter that triggers the
            # auto-recovery switch.
            is_repaired = row.get("decision_source") in {"llm_repaired", "auto_recovery_policy"}
            is_invalid = bool(row.get("invalid_action")) or bool(row.get("last_action_error"))
            model_used = str(row.get("model_used") or "")
            provider = str(row.get("provider") or "")
            stat_key = (provider, model_used)
            stat = self.llm_model_stats.get(stat_key)
            if stat is not None:
                if is_repaired:
                    stat["repaired"] += 1
                if is_invalid:
                    stat["invalid"] += 1
                    stat["failures"] += 1
                else:
                    # A clean step slowly forgives earlier failures.
                    stat["failures"] = max(0, int(stat.get("failures", 0)) - 1)
            is_failure_pattern = is_invalid or is_repaired
            self.consecutive_failure_steps = self.consecutive_failure_steps + 1 if is_failure_pattern else 0
            if self.consecutive_failure_steps >= 4:
                # Four bad steps in a row: cool down the failing model and
                # hand control to the heuristic policy for a few steps.
                if stat is not None:
                    stat["cooldown_until_step"] = self.step_idx + 6
                self.recovery_steps_remaining = max(self.recovery_steps_remaining, 3)
                self.auto_switch_count += 1
                self.last_switch_reason = "repeated invalid/repaired pattern detected"
                row["switch_note"] = "auto-switched to recovery policy and deprioritized failing model"
                self.consecutive_failure_steps = 0
        return row

    def _step_trained(self) -> dict[str, Any]:
        """Execute one step in trained-RL mode and build its trace row.

        Recurrent models predict with carried LSTM state and are manually
        clamped to the current action mask; maskable models get the mask via
        sb3_contrib's get_action_masks helper.
        """
        import numpy as np
        masks = self.rl_env.action_masks()
        if self.model_type == "recurrent":
            action, self.rl_lstm_state = self.rl_model.predict(
                self.obs,
                state=self.rl_lstm_state,
                episode_start=self.rl_episode_start,
                deterministic=True,
            )
            action_idx = int(action.item() if hasattr(action, "item") else action)
            # RecurrentPPO does not support masks natively, so enforce the
            # mask here: out-of-mask picks snap to the first valid index.
            if not (0 <= action_idx < masks.shape[0] and bool(masks[action_idx])):
                valid = np.flatnonzero(masks)
                # NOTE(review): 18 is a hard-coded last-resort action index --
                # presumably the advance_time no-op slot; confirm against the
                # action decode table.
                action_idx = int(valid[0]) if valid.size > 0 else 18
        else:
            from sb3_contrib.common.maskable.utils import get_action_masks
            action, _ = self.rl_model.predict(
                self.obs,
                action_masks=get_action_masks(self.rl_env),
                deterministic=True,
            )
            action_idx = int(action.item() if hasattr(action, "item") else action)
        self.obs, reward, terminated, truncated, info = self.rl_env.step(action_idx)
        done = bool(terminated or truncated)
        if self.model_type == "recurrent":
            self.rl_episode_start = np.array([done], dtype=bool)
        # Rich fields come from the wrapped core env's observation, not the
        # gym-space observation the model sees.
        core_env = self.rl_env.core_env
        core_obs = core_env._build_observation()
        action_model, action_label = _decode_action_idx(action_idx)
        return {
            "step": self.step_idx,
            "day": core_obs.day,
            "action_type": action_label,
            "action_payload": action_model.model_dump(exclude_none=True, mode="json"),
            "action_index": action_idx,
            "reward": float(reward),
            "done": done,
            "backlog": core_obs.total_backlog,
            "completed": core_obs.total_completed,
            "sla_breaches": core_obs.total_sla_breaches,
            "fairness_gap": float(
                getattr(core_obs, "fairness_gap", getattr(core_obs, "fairness_index", 0.0)) or 0.0
            ),
            "escalation_budget_remaining": core_obs.escalation_budget_remaining,
            "invalid_action": bool(info.get("invalid_action", False)),
            "last_action_error": info.get("last_action_error") or info.get("action_explanation"),
            "queue_rows": _queue_rows(core_obs),
            "decision_source": "trained_rl",
            "provider": "rl",
            "model_used": self.model_path or "trained_rl",
            "llm_attempts": 0,
            "llm_error": None,
            "llm_key_label": None,
        }

    def _finalize(self) -> None:
        """Grade the finished episode and assemble the summary dict.

        Idempotent: subsequent calls return immediately once done is set.
        """
        if self.done:
            return
        self.done = True
        from app.graders import grade_episode
        if self.agent_mode == SimulationAgentMode.TRAINED_RL:
            final_state = self.rl_env.core_env.state()
        else:
            final_state = self.env.state()
        gr = grade_episode(final_state)
        self.score = float(gr.score)
        self.grader_name = gr.grader_name
        # Aggregate decision-source counts over the trace.
        llm_steps = sum(1 for row in self.trace if row.get("decision_source") in {"llm", "llm_repaired"})
        fallback_steps = sum(
            1 for row in self.trace if row.get("decision_source") in {"heuristic_fallback", "auto_recovery_policy"}
        )
        repaired_steps = sum(
            1 for row in self.trace if row.get("decision_source") in {"llm_repaired", "auto_recovery_policy"}
        )
        total_steps = max(1, len(self.trace))
        invalid_actions = _safe_invalid_action_count(final_state)
        invalid_rate = float(invalid_actions) / float(total_steps)
        repaired_rate = float(repaired_steps) / float(total_steps)
        # Rank only models that were actually called, best-behaved first.
        ranked_models: list[dict[str, Any]] = []
        if self.llm_model_stats:
            for (provider, model), stat in self.llm_model_stats.items():
                calls = int(stat.get("calls", 0))
                if calls <= 0:
                    continue
                ranked_models.append(
                    {
                        "provider": provider,
                        "model": model,
                        "calls": calls,
                        "invalid_rate": float(stat.get("invalid", 0)) / max(1, calls),
                        "repaired_rate": float(stat.get("repaired", 0)) / max(1, calls),
                    }
                )
            ranked_models.sort(key=lambda x: (x["invalid_rate"], x["repaired_rate"], -x["calls"]))
        self.summary = {
            "total_steps": getattr(final_state, "total_steps", len(self.trace)),
            "total_completed": getattr(final_state, "total_completed", 0),
            "total_backlog": getattr(final_state, "total_backlog", 0),
            "total_sla_breaches": getattr(final_state, "total_sla_breaches", 0),
            "fairness_gap": float(getattr(final_state, "fairness_gap", 0.0) or 0.0),
            "total_invalid_actions": invalid_actions,
            "invalid_action_rate": invalid_rate,
            "llm_steps": llm_steps,
            "heuristic_fallback_steps": fallback_steps,
            "llm_repaired_steps": repaired_steps,
            "repaired_action_rate": repaired_rate,
            "auto_switch_count": self.auto_switch_count,
            "last_switch_reason": self.last_switch_reason,
            "effective_max_steps": self.max_steps,
            "recommended_min_steps": _recommended_min_steps(self.task_id),
        }
        if self.agent_mode == SimulationAgentMode.LLM_INFERENCE:
            self.summary["llm_route"] = list(self.llm_route)
            self.summary["llm_model_performance"] = ranked_models
        if self.agent_mode == SimulationAgentMode.TRAINED_RL:
            self.summary["model_path"] = self.model_path
            self.summary["model_type"] = self.model_type
def run_simulation(
    *,
    task_id: str,
    agent_mode: SimulationAgentMode,
    max_steps: int,
    seed: int | None,
    policy_name: str | None = None,
    model_path: str | None = None,
    model_type: Literal["maskable", "recurrent"] = "maskable",
) -> SimulationRun:
    """Run one full simulation episode to completion.

    Creates a LiveSimulationSession, steps it until the episode finishes,
    and packages the outcome as a SimulationRun. The session is always
    closed, even when stepping raises.
    """
    session = LiveSimulationSession(
        task_id=task_id,
        agent_mode=agent_mode,
        max_steps=max_steps,
        seed=seed,
        policy_name=policy_name,
        model_path=model_path,
        model_type=model_type,
    )
    try:
        while not session.done:
            session.step_once()
        result_fields = {
            "task_id": session.task_id,
            "agent_mode": session.agent_mode,
            "seed": session.seed,
            "total_reward": float(session.total_reward),
            "score": float(session.score or 0.0),
            "grader_name": str(session.grader_name or "unknown"),
            "summary": dict(session.summary or {}),
            "trace": list(session.trace),
        }
        return SimulationRun(**result_fields)
    finally:
        session.close()
def _decode_action_idx(action_idx: int) -> tuple[ActionModel, str]:
    """Map a discrete RL action index back to an ActionModel plus label.

    Falls back to a plain advance_time action labelled ``action_<idx>``
    whenever the decode table cannot be imported, the index is unknown, or
    the stored action-type string is not a valid ActionType.
    """
    fallback = (ActionModel(action_type=ActionType.ADVANCE_TIME), f"action_{action_idx}")
    try:
        from rl.feature_builder import ACTION_DECODE_TABLE
    except Exception:
        return fallback
    row = ACTION_DECODE_TABLE.get(int(action_idx))
    if row is None:
        return fallback
    raw_type, service, priority_mode, delta = row
    try:
        at = ActionType(str(raw_type))
    except Exception:
        return fallback
    if at == ActionType.SET_PRIORITY_MODE:
        decoded = _action_model_from_kwargs(at, priority_mode=priority_mode)
    elif at == ActionType.ASSIGN_CAPACITY:
        decoded = _action_model_from_kwargs(at, service=service, officer_delta=delta or 1)
    elif at in (ActionType.REQUEST_MISSING_DOCUMENTS, ActionType.ESCALATE_SERVICE):
        decoded = _action_model_from_kwargs(at, service=service)
    elif at == ActionType.REALLOCATE_OFFICERS:
        src = _enum_service(service)
        if src is None:
            decoded = ActionModel(action_type=ActionType.ADVANCE_TIME)
        else:
            # NOTE(review): the decode-table row carries only one service, so
            # target_service is set equal to the source here; confirm the env
            # accepts source==target for reallocate_officers.
            decoded = _action_model_from_kwargs(at, service=src, target_service=src, officer_delta=delta or 1)
    else:
        decoded = ActionModel(action_type=ActionType.ADVANCE_TIME)
    return decoded, at.value