# OPENENV_RL_01 / app/engine.py
# Author: Siddharaj Shirke
# deploy: fresh snapshot to Hugging Face Space (commit 3eae4cc)
from __future__ import annotations

import json
import os
import random
import re
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional

from openai import OpenAI

from app.event_engine import EventEngine
from app.models import (
    ActionModel,
    ActionType,
    ApplicationCase,
    DelayedEffect,
    EventType,
    IntakeChannel,
    InternalSubstate,
    ObservationModel,
    PriorityMode,
    QueueSnapshot,
    ServiceType,
    StageType,
)
from app.sector_profiles import get_sector_profile
from app.state_machine import can_advance

if TYPE_CHECKING:
    from app.models import TaskConfig
# Fallback pool of NVIDIA-hosted chat models, tried when no explicit
# NVIDIA_MODEL / NVIDIA_MODEL_FALLBACKS environment configuration is present
# (see LiveSimulationSession._init_llm_runtimes). Order matters: earlier
# entries are preferred candidates.
LEGACY_NVIDIA_MODEL_POOL = [
    "meta/llama-3.3-70b-instruct",
    "qwen/qwen3-next-80b-a3b-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "meta/llama-3.1-405b-instruct",
    "deepseek-ai/deepseek-v3.2",
    "qwen/qwq-32b",
    "mistralai/mixtral-8x22b-instruct-v0.1",
    "google/gemma-3-27b-it",
    "microsoft/phi-4-mini-instruct",
    "meta/llama-3.1-8b-instruct",
]
# Process-wide cache of loaded RL checkpoints, keyed by
# (absolute model path, model type). Populated by _load_model_cached_or_raise.
_MODEL_CACHE: dict[tuple[str, str], Any] = {}
# ─────────────────────────────────────────────
# DAY RESULT
# ─────────────────────────────────────────────
@dataclass
class DayResult:
    """Aggregated counters describing what happened during one simulated day.

    Populated by ``DaySimulator.simulate_day`` and read by callers to update
    metrics and rewards.  All counters start at zero; ``active_events`` starts
    as a fresh empty list per instance.

    Converted to a dataclass to drop the hand-written ``__init__`` boilerplate
    and gain ``__repr__``/``__eq__`` for free; the ``DayResult()`` constructor
    and every attribute name are unchanged.
    """

    new_arrivals: int = 0            # cases spawned today
    new_completions: int = 0         # cases finished today
    new_sla_breaches: int = 0        # cases that newly crossed their SLA deadline
    total_capacity_days: int = 0     # officer-days available today
    idle_officer_days: int = 0       # officer-days wasted on empty queues
    stage_advances: int = 0          # total stage transitions performed
    newly_unblocked_missing: int = 0  # cases unblocked after docs arrived
    newly_blocked_missing: int = 0    # cases newly blocked on missing docs
    newly_unblocked_enrich: int = 0   # cases unblocked after enrichment
    field_verif_completed: int = 0    # field verifications finished today
    urgent_completed: int = 0         # urgent cases completed today
    digital_arrivals: int = 0         # arrivals via the digital channel
    # Events in effect today; default_factory keeps lists per-instance.
    active_events: list[EventType] = field(default_factory=list)
# ─────────────────────────────────────────────
# DAY SIMULATOR
# ─────────────────────────────────────────────
class DaySimulator:
    """
    Core daily simulation engine.
    Accepts TWO calling conventions so both env.py and tests work:
    Convention A (tests):
        DaySimulator(task_config=task, rng=rng, event_engine=engine)
    Convention B (env.py legacy):
        DaySimulator(seed=42, task_config=task, sector_registry={})
        — in this case rng and event_engine are built internally.
    """

    def __init__(
        self,
        task_config: "TaskConfig",
        rng: Optional[random.Random] = None,
        event_engine: Optional[EventEngine] = None,
        seed: Optional[int] = None,
        sector_registry: Optional[dict] = None,
    ) -> None:
        self.task_config = task_config
        # Legacy alias: some callers still reference `self.task`.
        self.task = task_config
        # RNG precedence: explicit rng > explicit seed > task_config.seed.
        if rng is not None:
            self.rng = rng
        elif seed is not None:
            self.rng = random.Random(seed)
        else:
            self.rng = random.Random(task_config.seed)
        # Event engine precedence mirrors the RNG precedence above.
        if event_engine is not None:
            self.event_engine = event_engine
        else:
            _seed = seed if seed is not None else task_config.seed
            self.event_engine = EventEngine(
                seed=_seed,
                scenario_mode=task_config.scenario_mode,
            )
        self.sector_registry = sector_registry or {}
        self.active_cases: list[ApplicationCase] = []
        self.pending_effects: list[DelayedEffect] = []
        # Monotonic counter used to mint unique case ids in _new_case.
        self.case_counter: int = 0

    def simulate_day(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        completed_cases: list[ApplicationCase],
        priority_mode: PriorityMode,
        officer_allocations: dict,
    ) -> DayResult:
        """Advance the world by one day.

        Mutates `active_cases` and `completed_cases` IN PLACE (new arrivals
        appended, finished cases moved across) and returns a DayResult with
        the day's counters.
        """
        result = DayResult()
        # 1) Resolve today's events and derive effective simulation parameters.
        events = self.event_engine.get_events_for_day(day, self.task_config)
        params = self.event_engine.apply_events(events, self.task_config)
        result.active_events = list(params.active_events)
        # 2) Spawn today's arrivals and put them in the live queue.
        new_cases = self._spawn_arrivals(day, params, result)
        active_cases.extend(new_cases)
        # 3) Events may sideline officers; shrink allocations accordingly.
        effective_alloc = self._apply_officer_reduction(officer_allocations, params)
        # 4) Unblock cases whose field verification / document requests
        #    resolve today, before capacity is spent.
        self._resolve_field_verification(day, active_cases, result)
        self._resolve_doc_requests(day, active_cases, result)
        newly_completed: list[ApplicationCase] = []
        # 5) Spend per-service officer capacity advancing queued cases.
        for service in self.task_config.enabled_services:
            # Allocations may be keyed by enum or by the enum's string value.
            capacity = effective_alloc.get(service, effective_alloc.get(service.value, 0))
            result.total_capacity_days += int(capacity)
            service_cases = [
                c
                for c in active_cases
                if c.service_type == service and not c.completed and not c.rejected
            ]
            if not service_cases:
                # Nothing to work on: this service's capacity is wasted today.
                result.idle_officer_days += int(capacity)
                continue
            sorted_cases = self._sort_queue(service_cases, priority_mode)
            for case in sorted_cases:
                if capacity <= 0:
                    break
                # NOTE(review): local import kept as in the original; hoisting
                # it would be a behavior-neutral cleanup to confirm separately.
                from app.state_machine import advance_case
                advanced, final = advance_case(case, day)
                if advanced:
                    # One officer-day is consumed per successful advance.
                    capacity -= 1
                    result.stage_advances += 1
                    if final:
                        newly_completed.append(case)
                        if case.is_urgent:
                            result.urgent_completed += 1
        # 6) Move finished cases out of the active list (in place, so callers
        #    holding a reference to the list see the update).
        if newly_completed:
            done_ids = {c.case_id for c in newly_completed}
            still_active = [c for c in active_cases if c.case_id not in done_ids]
            active_cases.clear()
            active_cases.extend(still_active)
            completed_cases.extend(newly_completed)
        result.new_completions = len(newly_completed)
        # 7) Age every remaining case and record fresh SLA breaches.
        for case in active_cases:
            case.current_day = day
            case.waiting_days += 1
            if day > case.sla_deadline_day and not case.sla_breached:
                case.sla_breached = True
                result.new_sla_breaches += 1
        return result

    def _apply_officer_reduction(self, allocations: dict, params: Any) -> dict:
        """Return a copy of `allocations` with `params.officer_reduction`
        officers removed, taking each one from the currently largest bucket."""
        reduction = int(getattr(params, "officer_reduction", 0))
        if reduction <= 0:
            return dict(allocations)
        effective = dict(allocations)
        for _ in range(reduction):
            # Pick the service with the most officers; stop once nothing is left.
            target = max(effective, key=lambda k: effective[k], default=None)
            if target is None or effective[target] <= 0:
                break
            effective[target] -= 1
        return effective

    def _spawn_arrivals(
        self,
        day: int,
        params: Any,
        result: DayResult,
    ) -> list[ApplicationCase]:
        """Create today's new cases per enabled service.

        Uses stochastic rounding: the fractional part of the effective arrival
        rate becomes the probability of one extra case.
        """
        new_cases: list[ApplicationCase] = []
        for service in self.task_config.enabled_services:
            # Rates may be keyed by enum or by its string value.
            base_rate = self.task_config.arrival_rate_per_day.get(
                service,
                self.task_config.arrival_rate_per_day.get(service.value, 0.0),
            )
            effective_rate = float(base_rate) * float(getattr(params, "arrival_multiplier", 1.0))
            count = int(effective_rate)
            if self.rng.random() < (effective_rate - count):
                count += 1
            for _ in range(count):
                case = self._new_case(service, day, params)
                new_cases.append(case)
                if case.intake_channel == IntakeChannel.DIGITAL:
                    result.digital_arrivals += 1
        result.new_arrivals = len(new_cases)
        return new_cases

    def _new_case(self, service: ServiceType, day: int, params: Any) -> ApplicationCase:
        """Mint a single ApplicationCase.

        Rolls intake channel, missing-documents status, field-verification
        requirement, and urgency from the sector profile combined with the
        day's event parameters and any task-level overrides.
        """
        self.case_counter += 1
        profile = get_sector_profile(service)
        # Events can stretch or shrink the SLA window.
        sla_days = int(profile.sla_days * getattr(params, "sla_window_multiplier", 1.0))
        sla_deadline_day = day + sla_days
        digital_ratio = self.task_config.digital_intake_ratio
        channel = (
            IntakeChannel.DIGITAL
            if self.rng.random() < digital_ratio
            else IntakeChannel.PAPER
        )
        base_missing = profile.missing_docs_probability
        # Overrides may be keyed by enum or by its string value.
        override = (self.task_config.missing_docs_probability_override or {}).get(
            service,
            (self.task_config.missing_docs_probability_override or {}).get(service.value),
        )
        if override is not None:
            base_missing = override
        defect_rate = (
            profile.doc_defect_rate_digital
            if channel == IntakeChannel.DIGITAL
            else profile.doc_defect_rate_paper
        )
        # Event boost is scaled by the channel's defect rate; probability is
        # capped at 1.0.
        eff_missing = min(
            1.0,
            base_missing + getattr(params, "doc_defect_rate_boost", 0.0) * defect_rate,
        )
        has_missing = self.rng.random() < eff_missing
        base_fv = profile.field_verification_probability
        fv_override = (self.task_config.field_verification_probability_override or {}).get(
            service,
            (self.task_config.field_verification_probability_override or {}).get(service.value),
        )
        if fv_override is not None:
            base_fv = fv_override
        eff_fv = min(1.0, base_fv + getattr(params, "field_verification_boost", 0.0))
        has_fv = self.rng.random() < eff_fv
        field_completion_day = day + profile.field_verification_days if has_fv else None
        from app.models import UrgencyProfile
        urgency_profile = profile.urgency_profile
        # Urgency odds: 20% for HIGH-urgency sectors, 8% for MODERATE ones.
        is_urgent = (
            urgency_profile == UrgencyProfile.HIGH and self.rng.random() < 0.20
        ) or (
            urgency_profile == UrgencyProfile.MODERATE and self.rng.random() < 0.08
        )
        return ApplicationCase(
            case_id=f"case-{self.case_counter:06d}",
            service_type=service,
            arrival_day=day,
            current_day=day,
            sla_deadline_day=sla_deadline_day,
            intake_channel=channel,
            internal_substate=(
                InternalSubstate.BLOCKED_MISSING_DOCS
                if has_missing
                else InternalSubstate.PRE_SCRUTINY
            ),
            public_stage=StageType.SUBMISSION,
            is_urgent=is_urgent,
            has_missing_docs=has_missing,
            field_verification_required=has_fv,
            field_verification_completion_day=field_completion_day,
        )

    def _resolve_field_verification(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        result: DayResult,
    ) -> None:
        """Move cases whose field verification finished on/before `day` back
        into PRE_SCRUTINY, clearing their completion-day marker."""
        for case in active_cases:
            if (
                case.internal_substate == InternalSubstate.FIELD_VERIFICATION_PENDING
                and case.field_verification_completion_day is not None
                and day >= case.field_verification_completion_day
            ):
                case.internal_substate = InternalSubstate.PRE_SCRUTINY
                case.field_verification_completion_day = None
                result.field_verif_completed += 1

    def _resolve_doc_requests(
        self,
        day: int,
        active_cases: list[ApplicationCase],
        result: DayResult,
    ) -> None:
        """Unblock cases whose requested documents arrive on/before `day`."""
        for case in active_cases:
            if (
                case.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
                and case.doc_resolution_day is not None
                and day >= case.doc_resolution_day
            ):
                case.internal_substate = InternalSubstate.PRE_SCRUTINY
                case.doc_resolution_day = None
                result.newly_unblocked_missing += 1

    def _sort_queue(
        self,
        cases: list[ApplicationCase],
        priority_mode: PriorityMode,
    ) -> list[ApplicationCase]:
        """Order the advanceable cases according to the active priority mode.

        Cases failing `can_advance` are excluded entirely. Sort keys use
        (False < True) so `not c.is_urgent` puts urgent cases first, and
        negated sla_risk sorts highest-risk first.
        """
        eligible = [c for c in cases if can_advance(c)]
        if priority_mode == PriorityMode.URGENT_FIRST:
            # Urgent first, then highest SLA risk, then oldest arrival.
            return sorted(
                eligible,
                key=lambda c: (not c.is_urgent, -c.sla_risk, c.arrival_day),
            )
        if priority_mode == PriorityMode.OLDEST_FIRST:
            return sorted(eligible, key=lambda c: c.arrival_day)
        if priority_mode == PriorityMode.BACKLOG_CLEARANCE:
            # Highest SLA risk first, urgency as tiebreaker, then oldest.
            return sorted(
                eligible,
                key=lambda c: (-c.sla_risk, not c.is_urgent, c.arrival_day),
            )
        # Default mode: only very-high-risk cases (> 0.8) jump the queue;
        # everything else is ordered by urgency, then arrival.
        return sorted(
            eligible,
            key=lambda c: (
                -c.sla_risk if c.sla_risk > 0.8 else 0,
                not c.is_urgent,
                c.arrival_day,
            ),
        )

    def build_queue_snapshot(
        self,
        service: ServiceType,
        active_cases: list[ApplicationCase],
        day: int,
    ) -> QueueSnapshot:
        """Summarise the live (not completed/rejected) queue for one service
        into a QueueSnapshot. `total_completed_today` is always 0 here; the
        caller is expected to fill it from the day's DayResult if needed."""
        cases = [
            c
            for c in active_cases
            if c.service_type == service and not c.completed and not c.rejected
        ]
        # Pre-seed every public stage with 0 so the mapping is always complete.
        stage_counts = {s.value: 0 for s in StageType}
        for c in cases:
            stage_counts[c.public_stage.value] = stage_counts.get(c.public_stage.value, 0) + 1
        oldest_age = max((c.waiting_days for c in cases), default=0)
        avg_wait = sum(c.waiting_days for c in cases) / len(cases) if cases else 0.0
        sla_risk = sum(c.sla_risk for c in cases) / len(cases) if cases else 0.0
        return QueueSnapshot(
            service_type=service,
            public_stage_counts=stage_counts,
            total_pending=len(cases),
            total_completed_today=0,
            total_sla_breached=sum(1 for c in cases if c.sla_breached),
            urgent_pending=sum(1 for c in cases if c.is_urgent),
            blocked_missing_docs=sum(
                1
                for c in cases
                if c.internal_substate == InternalSubstate.BLOCKED_MISSING_DOCS
            ),
            field_verification_pending=sum(
                1
                for c in cases
                if c.internal_substate == InternalSubstate.FIELD_VERIFICATION_PENDING
            ),
            oldest_case_age_days=oldest_age,
            avg_waiting_days=round(avg_wait, 2),
            current_sla_risk=round(min(1.0, sla_risk), 3),
        )
# ─────────────────────────────────────────────
# HIGH-LEVEL SIMULATION ORCHESTRATION
# ─────────────────────────────────────────────
class SimulationAgentMode(str, Enum):
    """Which decision-maker drives a simulation run."""

    BASELINE_POLICY = "baseline_policy"  # scripted heuristic policy
    LLM_INFERENCE = "llm_inference"      # live LLM picks actions each step
    TRAINED_RL = "trained_rl"            # pre-trained RL checkpoint
@dataclass
class SimulationRun:
    """Record of one finished simulation episode and its scoring."""

    task_id: str  # scenario/task identifier the episode ran against
    agent_mode: SimulationAgentMode  # which decision-maker drove the run
    seed: int  # RNG seed used for the episode
    total_reward: float  # accumulated environment reward
    score: float  # grader-assigned score
    grader_name: str  # name of the grader that produced `score`
    summary: dict[str, Any]  # aggregate episode metrics
    trace: list[dict[str, Any]]  # per-step records of the episode
def _dedupe(values: list[str | None]) -> list[str]:
out: list[str] = []
for value in values:
if value is None:
continue
v = str(value).strip()
if v and v not in out:
out.append(v)
return out
def _env_csv_list(name: str) -> list[str]:
raw = os.getenv(name, "").strip()
if not raw:
return []
return [x.strip() for x in raw.split(",") if x.strip()]
def _extract_json_object(text: str) -> dict[str, Any] | None:
text = (text or "").strip()
if not text:
return None
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _enum_service(value: Any) -> ServiceType | None:
    """Coerce *value* into a ServiceType; None when empty or unrecognised."""
    if isinstance(value, ServiceType):
        return value
    if value is None or value == "":
        return None
    try:
        return ServiceType(str(value))
    except Exception:
        return None
def _enum_priority(value: Any) -> PriorityMode | None:
    """Coerce *value* into a PriorityMode; None when empty or unrecognised."""
    if isinstance(value, PriorityMode):
        return value
    if value is None or value == "":
        return None
    try:
        return PriorityMode(str(value))
    except Exception:
        return None
def _action_model_from_kwargs(action_type: ActionType, **kwargs: Any) -> ActionModel:
    """Build an ActionModel tolerantly from loose keyword arguments.

    For each action type, several candidate payload shapes are tried IN ORDER
    (the ActionModel schema apparently accepts different field spellings);
    the first candidate the model validates wins.  When nothing validates —
    or required pieces like a service are absent — the function degrades to
    a plain ADVANCE_TIME action rather than raising.
    """
    # Normalise the loose inputs up front.
    service = _enum_service(kwargs.get("service") or kwargs.get("service_target"))
    target_service = _enum_service(kwargs.get("target_service"))
    escalation_target = _enum_service(kwargs.get("escalation_target"))
    priority_mode = _enum_priority(kwargs.get("priority_mode"))
    officer_delta = kwargs.get("officer_delta")
    case_id = kwargs.get("case_id")
    candidates: list[dict[str, Any]] = []
    if action_type == ActionType.ADVANCE_TIME:
        candidates.append({"action_type": action_type})
    elif action_type == ActionType.SET_PRIORITY_MODE:
        candidates.extend(
            [
                {"action_type": action_type, "priority_mode": priority_mode},
            ]
        )
    elif action_type == ActionType.ASSIGN_CAPACITY:
        if service is not None:
            # Deltas are clamped to at least one officer.
            delta = max(1, int(officer_delta or 1))
            candidates.extend(
                [
                    {"action_type": action_type, "service": service, "officer_delta": delta},
                    {"action_type": action_type, "service_target": service, "officer_delta": delta},
                    {
                        "action_type": action_type,
                        "capacity_assignment": {service.value: delta},
                    },
                ]
            )
    elif action_type == ActionType.REQUEST_MISSING_DOCUMENTS:
        if service is not None:
            candidates.extend(
                [
                    {"action_type": action_type, "service": service},
                    {"action_type": action_type, "service_target": service},
                ]
            )
    elif action_type == ActionType.ESCALATE_SERVICE:
        # Either an explicit escalation target or the generic service works.
        svc = escalation_target or service
        candidates.extend(
            [
                {"action_type": action_type, "service": svc, "case_id": case_id},
                {"action_type": action_type, "service_target": svc, "case_id": case_id},
                {"action_type": action_type, "escalation_target": svc, "case_id": case_id},
            ]
        )
    elif action_type == ActionType.REALLOCATE_OFFICERS:
        if service is not None and target_service is not None:
            delta = max(1, int(officer_delta or 1))
            candidates.extend(
                [
                    {
                        "action_type": action_type,
                        "service": service,
                        "target_service": target_service,
                        "officer_delta": delta,
                    },
                    {
                        # Alternate map-based shape: negative = source,
                        # positive = destination.
                        "action_type": action_type,
                        "reallocation_delta": {
                            service.value: -delta,
                            target_service.value: delta,
                        },
                    },
                ]
            )
    # First shape the model accepts wins; validation errors just move on.
    for candidate in candidates:
        try:
            return ActionModel(**candidate)
        except Exception:
            continue
    return ActionModel(action_type=ActionType.ADVANCE_TIME)
def _coerce_action(payload: dict[str, Any] | None) -> ActionModel:
    """Convert a loosely-structured dict (typically LLM output) into an ActionModel.

    Accepts both snake_case and camelCase key spellings, and recovers the
    service / officer_delta fields from `capacity_assignment` or
    `reallocation_delta` maps when the flat fields are absent.  Missing or
    unknown action types degrade to ADVANCE_TIME.
    """
    if not payload:
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    raw_action_type = payload.get("action_type") or payload.get("actionType")
    try:
        action_type = ActionType(str(raw_action_type))
    except Exception:
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
    # Tolerate both snake_case and camelCase for every field.
    service = payload.get("service") or payload.get("service_target") or payload.get("serviceTarget")
    target_service = payload.get("target_service") or payload.get("targetService")
    escalation_target = payload.get("escalation_target") or payload.get("escalationTarget")
    priority_mode = payload.get("priority_mode") or payload.get("priorityMode")
    officer_delta = payload.get("officer_delta") or payload.get("officerDelta")
    case_id = payload.get("case_id") or payload.get("caseId")
    if action_type == ActionType.ASSIGN_CAPACITY and not service:
        # Fall back to the map form: first (service, delta) pair wins.
        assignment = payload.get("capacity_assignment") or {}
        if isinstance(assignment, dict) and assignment:
            service, officer_delta = next(iter(assignment.items()))
    if action_type == ActionType.REALLOCATE_OFFICERS and (not service or not target_service):
        # Map form: negative entries are sources, positive are destinations.
        delta_map = payload.get("reallocation_delta") or {}
        if isinstance(delta_map, dict) and len(delta_map) >= 2:
            negatives = [k for k, v in delta_map.items() if int(v) < 0]
            positives = [k for k, v in delta_map.items() if int(v) > 0]
            if negatives and positives:
                service = negatives[0]
                target_service = positives[0]
                officer_delta = abs(int(delta_map[service]))
    return _action_model_from_kwargs(
        action_type,
        service=service,
        target_service=target_service,
        escalation_target=escalation_target,
        priority_mode=priority_mode,
        officer_delta=officer_delta,
        case_id=case_id,
    )
def _recommended_min_steps(task_id: str) -> int:
if task_id == "cross_department_hard":
return 70
if task_id == "mixed_urgency_medium":
return 60
return 40
def _queue_snapshot_iter(obs: ObservationModel) -> list[Any]:
raw = getattr(obs, "queue_snapshots", [])
if isinstance(raw, dict):
return list(raw.values())
if isinstance(raw, list):
return list(raw)
try:
return list(raw)
except Exception:
return []
def _queue_service(q: Any) -> ServiceType | None:
    """Extract the snapshot's service, accepting either attribute spelling."""
    raw = getattr(q, "service", None) or getattr(q, "service_type", None)
    return _enum_service(raw)
def _queue_active_cases(q: Any) -> int:
return int(getattr(q, "active_cases", getattr(q, "total_pending", 0)) or 0)
def _queue_missing_docs(q: Any) -> int:
return int(getattr(q, "missing_docs_cases", getattr(q, "blocked_missing_docs", 0)) or 0)
def _queue_urgent_cases(q: Any) -> int:
return int(getattr(q, "urgent_cases", getattr(q, "urgent_pending", 0)) or 0)
def _queue_breached_cases(q: Any) -> int:
return int(getattr(q, "breached_cases", getattr(q, "total_sla_breached", 0)) or 0)
def _queue_avg_age(q: Any) -> float:
if hasattr(q, "avg_age_days"):
return float(getattr(q, "avg_age_days") or 0.0)
if hasattr(q, "oldest_case_age_days"):
return float(getattr(q, "oldest_case_age_days") or 0.0)
return float(getattr(q, "avg_waiting_days", 0.0) or 0.0)
def _queue_rows(obs: ObservationModel) -> list[dict[str, Any]]:
    """Summarise every queue snapshot into a plain dict row.

    Snapshots whose service cannot be resolved are skipped.
    """
    summary: list[dict[str, Any]] = []
    for snapshot in _queue_snapshot_iter(obs):
        svc = _queue_service(snapshot)
        if svc is None:
            continue
        row = {
            "service": svc.value,
            "active_cases": _queue_active_cases(snapshot),
            "missing_docs_cases": _queue_missing_docs(snapshot),
            "urgent_cases": _queue_urgent_cases(snapshot),
            "breached_cases": _queue_breached_cases(snapshot),
            "avg_age_days": _queue_avg_age(snapshot),
        }
        summary.append(row)
    return summary
def _pool_allocations(obs: ObservationModel) -> dict[Any, Any]:
pool = getattr(obs, "officer_pool", None)
if pool is None:
return {}
return getattr(pool, "allocations", getattr(pool, "allocated", {})) or {}
def _reserve_officers(obs: ObservationModel) -> int:
pool = getattr(obs, "officer_pool", None)
if pool is None:
return 0
for name in ("reserve_officers", "idle_officers", "available_officers"):
if hasattr(pool, name):
try:
return int(getattr(pool, name) or 0)
except Exception:
pass
return 0
def _alloc_for(obs: ObservationModel, service: ServiceType) -> int:
    """Officers currently allocated to *service* (keyed by enum or its value)."""
    allocations = _pool_allocations(obs)
    value = allocations.get(service)
    if value is None:
        value = allocations.get(service.value, 0)
    return int(value) if value else 0
def _top_backlog_service(
    obs: ObservationModel,
    *,
    exclude: ServiceType | None = None,
) -> ServiceType | None:
    """Service whose queue looks most loaded.

    Load is ranked by (active + 2*breached + urgent), with average age as the
    tiebreaker.  *exclude* removes one service from consideration.
    """
    def _load(q: Any) -> tuple[int, float]:
        pressure = (
            _queue_active_cases(q)
            + (2 * _queue_breached_cases(q))
            + _queue_urgent_cases(q)
        )
        return (pressure, _queue_avg_age(q))

    candidates = [
        q
        for q in _queue_snapshot_iter(obs)
        if _queue_service(q) is not None and _queue_service(q) != exclude
    ]
    if not candidates:
        return None
    return _queue_service(max(candidates, key=_load))
def _service_with_missing_docs(obs: ObservationModel) -> ServiceType | None:
    """Service with the largest missing-documents backlog, or None if none."""
    blocked = [q for q in _queue_snapshot_iter(obs) if _queue_missing_docs(q) > 0]
    if not blocked:
        return None
    worst = max(blocked, key=lambda q: (_queue_missing_docs(q), _queue_active_cases(q)))
    return _queue_service(worst)
def _service_with_officers(obs: ObservationModel) -> ServiceType | None:
    """Service holding the most officers, or None when nothing is staffed."""
    services = [
        svc
        for svc in (_queue_service(q) for q in _queue_snapshot_iter(obs))
        if svc is not None
    ]
    if not services:
        return None
    best = max(services, key=lambda svc: _alloc_for(obs, svc))
    return best if _alloc_for(obs, best) > 0 else None
def _compute_action_mask(obs: ObservationModel) -> dict[ActionType, bool]:
    """Determine which action types are currently sensible for *obs*.

    SET_PRIORITY_MODE and ADVANCE_TIME are always allowed; the rest depend on
    reserve officers, backlog, missing documents, escalation budget, and
    whether a reallocation source/target pair can exist.
    """
    snapshots = _queue_snapshot_iter(obs)
    reserve_available = _reserve_officers(obs) > 0
    missing_docs_exist = any(_queue_missing_docs(q) > 0 for q in snapshots)
    backlog_exists = any(_queue_active_cases(q) > 0 for q in snapshots)
    budget_left = int(getattr(obs, "escalation_budget_remaining", 0) or 0) > 0
    staffed: list[Any] = []
    for q in snapshots:
        svc = _queue_service(q)
        if svc is not None and _alloc_for(obs, svc) > 0:
            staffed.append(q)
    reallocation_possible = bool(staffed) and len(snapshots) >= 2
    return {
        ActionType.SET_PRIORITY_MODE: True,
        ActionType.ADVANCE_TIME: True,
        ActionType.ASSIGN_CAPACITY: reserve_available and backlog_exists,
        ActionType.REQUEST_MISSING_DOCUMENTS: missing_docs_exist,
        ActionType.ESCALATE_SERVICE: budget_left and backlog_exists,
        ActionType.REALLOCATE_OFFICERS: reallocation_possible,
    }
def _masked_action_type_hints(obs: ObservationModel) -> tuple[list[str], list[str]]:
    """Split action-type values into (allowed, blocked) per the current mask."""
    allowed: list[str] = []
    blocked: list[str] = []
    for action_type, permitted in _compute_action_mask(obs).items():
        (allowed if permitted else blocked).append(action_type.value)
    return allowed, blocked
def _best_high_impact_action(obs: ObservationModel) -> tuple[ActionModel, str]:
    """Pick the heuristically best action for *obs*, with a human-readable reason.

    Preference order: (1) assign reserve capacity to the most loaded service,
    (2) request missing documents where most cases are blocked, (3) escalate
    the highest SLA-risk service while budget remains, (4) move one officer
    from the best-staffed service toward the top backlog, else (5) advance
    time as a no-op fallback.
    """
    top_backlog = _top_backlog_service(obs)
    top_missing = _service_with_missing_docs(obs)
    # 1) Spend reserve capacity where the backlog pressure is highest.
    if _reserve_officers(obs) > 0 and top_backlog is not None:
        return (
            _action_model_from_kwargs(
                ActionType.ASSIGN_CAPACITY,
                service=top_backlog,
                officer_delta=1,
            ),
            "high-impact: assign reserve capacity to top backlog service",
        )
    # 2) Unblock the biggest missing-documents bottleneck.
    if top_missing is not None:
        return (
            _action_model_from_kwargs(
                ActionType.REQUEST_MISSING_DOCUMENTS,
                service=top_missing,
            ),
            "high-impact: clear missing-document bottleneck",
        )
    # 3) Escalate while budget remains, ranking queues by breaches first.
    if int(getattr(obs, "escalation_budget_remaining", 0) or 0) > 0:
        hot = sorted(
            _queue_snapshot_iter(obs),
            key=lambda q: (_queue_breached_cases(q), _queue_active_cases(q), _queue_urgent_cases(q)),
            reverse=True,
        )
        if hot and (_queue_breached_cases(hot[0]) > 0 or _queue_active_cases(hot[0]) > 0):
            service = _queue_service(hot[0])
            if service is not None:
                return (
                    _action_model_from_kwargs(
                        ActionType.ESCALATE_SERVICE,
                        service=service,
                    ),
                    "high-impact: escalate highest SLA-risk service",
                )
    # 4) Shift one officer from the best-staffed service to the top backlog.
    source = _service_with_officers(obs)
    if source is not None and _alloc_for(obs, source) > 0:
        target = _top_backlog_service(obs, exclude=source)
        if target is not None and target != source:
            return (
                _action_model_from_kwargs(
                    ActionType.REALLOCATE_OFFICERS,
                    service=source,
                    target_service=target,
                    officer_delta=1,
                ),
                "high-impact: reallocate one officer toward highest backlog",
            )
    # 5) Nothing useful to do — just let time pass.
    return ActionModel(action_type=ActionType.ADVANCE_TIME), "fallback: no high-impact action available"
def _repair_action_for_observation(
    action: ActionModel,
    obs: ObservationModel,
) -> tuple[ActionModel, str | None]:
    """Validate *action* against the current observation and repair it if needed.

    Returns ``(action, None)`` when the action is usable as-is, or a repaired /
    substituted action plus a short note explaining what changed.  Actions the
    mask forbids, or whose preconditions (reserve officers, escalation budget,
    valid services) fail, are replaced by ``_best_high_impact_action``.
    """
    mask = _compute_action_mask(obs)
    at = action.action_type
    # Masked action types are replaced wholesale with the best heuristic move.
    if not bool(mask.get(at, True)):
        fallback, why = _best_high_impact_action(obs)
        return fallback, f"masked {at.value}; {why}"
    if at == ActionType.ADVANCE_TIME:
        return action, None
    if at == ActionType.SET_PRIORITY_MODE:
        # A priority-mode action with no mode gets a sensible default.
        if getattr(action, "priority_mode", None) is None:
            return (
                _action_model_from_kwargs(
                    ActionType.SET_PRIORITY_MODE,
                    priority_mode=PriorityMode.BACKLOG_CLEARANCE,
                ),
                "missing priority_mode, defaulted to backlog_clearance",
            )
        return action, None
    if at == ActionType.ASSIGN_CAPACITY:
        reserve = _reserve_officers(obs)
        if reserve <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"reserve officers exhausted; {why}"
        # Missing service falls back to the top backlog service.
        service = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _top_backlog_service(obs)
        if service is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no service available for assign_capacity; {why}"
        # Clamp the delta to [1, reserve].
        delta = max(1, int(getattr(action, "officer_delta", 1) or 1))
        delta = min(delta, reserve)
        repaired = _action_model_from_kwargs(
            ActionType.ASSIGN_CAPACITY,
            service=service,
            officer_delta=delta,
        )
        return repaired, "repaired assign_capacity payload"
    if at == ActionType.REQUEST_MISSING_DOCUMENTS:
        # Missing service falls back to the service most blocked on documents.
        service = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _service_with_missing_docs(obs)
        if service is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no missing-doc queue available; {why}"
        repaired = _action_model_from_kwargs(
            ActionType.REQUEST_MISSING_DOCUMENTS,
            service=service,
        )
        return repaired, "repaired request_missing_documents payload"
    if at == ActionType.ESCALATE_SERVICE:
        if int(getattr(obs, "escalation_budget_remaining", 0) or 0) <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"escalation budget exhausted; {why}"
        # Accept any of the three service spellings, else use the top backlog.
        service = (
            _enum_service(getattr(action, "service", None))
            or _enum_service(getattr(action, "service_target", None))
            or _enum_service(getattr(action, "escalation_target", None))
            or _top_backlog_service(obs)
        )
        case_id = getattr(action, "case_id", None)
        # Either a service or a specific case id is enough to escalate.
        if service is None and case_id is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no escalation target available; {why}"
        repaired = _action_model_from_kwargs(
            ActionType.ESCALATE_SERVICE,
            service=service,
            case_id=case_id,
        )
        return repaired, "repaired escalate_service payload"
    if at == ActionType.REALLOCATE_OFFICERS:
        source = _enum_service(getattr(action, "service", None) or getattr(action, "service_target", None)) or _service_with_officers(obs)
        if source is None:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"no staffed source service; {why}"
        source_alloc = _alloc_for(obs, source)
        if source_alloc <= 0:
            # The requested source has no officers: retry with the best-staffed one.
            source = _service_with_officers(obs)
            source_alloc = _alloc_for(obs, source) if source is not None else 0
        if source is None or source_alloc <= 0:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"insufficient source officers; {why}"
        target = _enum_service(getattr(action, "target_service", None))
        # Target must exist and differ from the source.
        if target is None or target == source:
            target = _top_backlog_service(obs, exclude=source)
        if target is None or target == source:
            fallback, why = _best_high_impact_action(obs)
            return fallback, f"missing distinct target_service; {why}"
        # Clamp the delta to [1, source allocation].
        delta = max(1, int(getattr(action, "officer_delta", 1) or 1))
        delta = min(delta, source_alloc)
        repaired = _action_model_from_kwargs(
            ActionType.REALLOCATE_OFFICERS,
            service=source,
            target_service=target,
            officer_delta=delta,
        )
        return repaired, "repaired reallocate_officers payload"
    # Unknown-but-unmasked action types pass through untouched.
    return action, None
def _model_label_for_mode(agent_mode: SimulationAgentMode) -> str:
    """Human-readable model label for logging, per agent mode.

    LLM inference reports the configured MODEL_NAME env var (default
    "llm_inference"); the other modes use fixed labels.
    """
    fixed_labels = {
        SimulationAgentMode.BASELINE_POLICY: "baseline_policy",
        SimulationAgentMode.TRAINED_RL: "trained_rl",
    }
    label = fixed_labels.get(agent_mode)
    if label is not None:
        return label
    return os.getenv("MODEL_NAME", "llm_inference")
def _log_step_line(step_row: dict[str, Any]) -> str:
done = "true" if bool(step_row.get("done")) else "false"
error = step_row.get("last_action_error") or "null"
action = json.dumps(step_row.get("action_payload", {}), separators=(",", ":"))
source = step_row.get("decision_source") or "unknown"
model = step_row.get("model_used") or "null"
repair = step_row.get("repair_note") or "null"
switch_note = step_row.get("switch_note") or "null"
return (
f"[STEP] step={step_row.get('step', 0)} action={action} "
f"reward={float(step_row.get('reward', 0.0)):.2f} done={done} "
f"error={error} source={source} model={model} repair={repair} switch={switch_note}"
)
def _resolve_model_path_or_raise(model_path: str) -> str:
p = Path(model_path).expanduser()
if not p.is_absolute():
p = (Path.cwd() / p).resolve()
if p.is_dir():
candidates = [
p / "best_model.zip",
p / "model.zip",
p / "checkpoint.zip",
]
zip_files = sorted(p.glob("*.zip"))
candidates.extend(zip_files)
for candidate in candidates:
if candidate.exists():
return str(candidate)
if p.exists():
return str(p)
raise FileNotFoundError(f"Model path not found: {model_path}")
def _load_model_cached_or_raise(model_abs: str, model_type: Literal["maskable", "recurrent"]) -> Any:
    """Load an RL checkpoint once per (path, type) pair, memoised in _MODEL_CACHE.

    "recurrent" loads sb3_contrib.RecurrentPPO; anything else tries
    sb3_contrib.MaskablePPO and falls back to stable_baselines3.PPO.
    Import/load errors from the final fallback propagate to the caller.
    """
    cache_key = (model_abs, model_type)
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]
    if model_type == "recurrent":
        from sb3_contrib import RecurrentPPO

        loaded = RecurrentPPO.load(model_abs)
    else:
        try:
            from sb3_contrib import MaskablePPO

            loaded = MaskablePPO.load(model_abs)
        except Exception:
            from stable_baselines3 import PPO

            loaded = PPO.load(model_abs)
    _MODEL_CACHE[cache_key] = loaded
    return loaded
def _safe_invalid_action_count(final_state: Any) -> int:
if hasattr(final_state, "total_invalid_actions"):
return int(getattr(final_state, "total_invalid_actions") or 0)
metrics = getattr(final_state, "metrics", None)
if metrics is not None and hasattr(metrics, "total_invalid_actions"):
return int(getattr(metrics, "total_invalid_actions") or 0)
return 0
class LiveSimulationSession:
def __init__(
self,
*,
task_id: str,
agent_mode: SimulationAgentMode,
max_steps: int,
seed: int | None,
policy_name: str | None = None,
model_path: str | None = None,
model_type: Literal["maskable", "recurrent"] = "maskable",
) -> None:
self.task_id = task_id
self.agent_mode = agent_mode
recommended = _recommended_min_steps(task_id)
self.max_steps = max(int(max_steps), int(recommended)) if agent_mode == SimulationAgentMode.LLM_INFERENCE else int(max_steps)
self.seed = int(seed if seed is not None else random.randint(1, 999999))
self.policy_name = policy_name or "backlog_clearance"
self.model_path = model_path
self.model_type = model_type
self.trace: list[dict[str, Any]] = []
self.total_reward = 0.0
self.step_idx = 0
self.done = False
self.summary: dict[str, Any] | None = None
self.score: float | None = None
self.grader_name: str | None = None
self.env: Any = None
self.obs: ObservationModel | Any = None
self.policy: Any = None
self.rl_env: Any = None
self.rl_model: Any = None
self.rl_lstm_state: Any = None
self.rl_episode_start: Any = None
self.llm_runtimes: list[dict[str, Any]] = []
self.llm_route: list[str] = []
self.llm_model_stats: dict[tuple[str, str], dict[str, Any]] = {}
self.consecutive_failure_steps = 0
self.recovery_steps_remaining = 0
self.auto_switch_count = 0
self.last_switch_reason: str | None = None
if self.agent_mode == SimulationAgentMode.TRAINED_RL:
self._init_trained()
else:
self._init_core()
def start_line(self) -> dict[str, Any]:
return {
"log": (
f"[START] task={self.task_id} env=gov-workflow-openenv "
f"model={_model_label_for_mode(self.agent_mode)}"
),
"observation": self.obs
}
def _init_core(self) -> None:
from app.baselines import POLICIES, backlog_clearance_policy
from app.env import GovWorkflowEnv
self.env = GovWorkflowEnv(task_id=self.task_id)
self.obs, _ = self.env.reset(seed=self.seed)
if self.agent_mode == SimulationAgentMode.BASELINE_POLICY:
self.policy = POLICIES.get(self.policy_name, backlog_clearance_policy)
else:
self.policy = self._llm_action_with_meta
self._init_llm_runtimes()
def _init_llm_runtimes(self) -> None:
openai_base = os.getenv("API_BASE_URL") or os.getenv("OPENAI_API_BASE_URL") or "https://api.openai.com/v1"
nvidia_base = os.getenv("NVIDIA_API_BASE_URL", "https://integrate.api.nvidia.com/v1")
openai_keys = _dedupe(
[
os.getenv("HF_TOKEN"),
os.getenv("OPENAI_API_KEY"),
os.getenv("API_KEY"),
]
)
nvidia_keys = _dedupe(
[
os.getenv("NVIDIA_API_KEY"),
os.getenv("NVIDIA_API_KEY_2"),
]
)
openai_models = _dedupe(
[
os.getenv("MODEL_NAME", "meta/llama-3.3-70b-instruct"),
*_env_csv_list("MODEL_FALLBACKS"),
]
)
nvidia_models = _dedupe(
[
os.getenv("NVIDIA_MODEL"),
*_env_csv_list("NVIDIA_MODEL_FALLBACKS"),
*LEGACY_NVIDIA_MODEL_POOL,
]
)
runtimes: list[dict[str, Any]] = []
if openai_keys and openai_models:
clients: list[tuple[OpenAI, str]] = []
for idx, key in enumerate(openai_keys, start=1):
try:
clients.append(
(
OpenAI(base_url=openai_base, api_key=key, timeout=8.0, max_retries=0),
f"openai_key_{idx}",
)
)
except Exception:
continue
if clients:
runtimes.append(
{
"provider": "openai-compatible",
"base_url": openai_base,
"clients": clients,
"models": openai_models,
}
)
if nvidia_keys and nvidia_models:
clients = []
for idx, key in enumerate(nvidia_keys, start=1):
try:
clients.append(
(
OpenAI(base_url=nvidia_base, api_key=key, timeout=8.0, max_retries=0),
f"nvidia_key_{idx}",
)
)
except Exception:
continue
if clients:
runtimes.append(
{
"provider": "nvidia",
"base_url": nvidia_base,
"clients": clients,
"models": nvidia_models,
}
)
self.llm_runtimes = runtimes
self.llm_model_stats = {}
for runtime in runtimes:
provider = str(runtime.get("provider"))
for model in runtime.get("models", []):
self.llm_model_stats[(provider, str(model))] = {
"calls": 0,
"invalid": 0,
"repaired": 0,
"failures": 0,
"cooldown_until_step": 0,
}
openai_runtime = next((rt for rt in runtimes if rt.get("provider") == "openai-compatible"), None)
nvidia_runtime = next((rt for rt in runtimes if rt.get("provider") == "nvidia"), None)
openai_route = (
f"openai-compatible ({len(openai_runtime['clients'])} keys, {len(openai_runtime['models'])} models)"
if openai_runtime is not None
else "openai-compatible (unavailable: missing API key/model)"
)
nvidia_route = (
f"nvidia ({len(nvidia_runtime['clients'])} keys, {len(nvidia_runtime['models'])} models)"
if nvidia_runtime is not None
else "nvidia (unavailable: missing API key/model)"
)
self.llm_route = [
openai_route,
nvidia_route,
"adaptive ranking: prefer models with lower invalid/repaired rates",
"heuristic fallback (backlog_clearance_policy)",
]
def _rank_runtime_models(self, provider: str, models: list[str]) -> list[str]:
def _score(model_name: str) -> tuple[float, int]:
stat = self.llm_model_stats.get((provider, model_name), {})
calls = max(1, int(stat.get("calls", 0)))
invalid_rate = float(stat.get("invalid", 0)) / calls
repaired_rate = float(stat.get("repaired", 0)) / calls
fail_rate = float(stat.get("failures", 0)) / calls
cooldown = int(stat.get("cooldown_until_step", 0))
cooldown_penalty = 1.0 if self.step_idx < cooldown else 0.0
return (
invalid_rate * 2.0 + repaired_rate * 1.25 + fail_rate * 1.5 + cooldown_penalty,
-calls,
)
return sorted([str(m) for m in models], key=_score)
def _llm_action_with_meta(self, obs: ObservationModel) -> tuple[ActionModel, dict[str, Any]]:
    """Choose the next action via the configured LLM providers.

    Tries every (provider, api-key, model) combination — models in the
    adaptive order produced by ``_rank_runtime_models`` — and returns
    the first successfully parsed action together with a metadata dict
    describing how the decision was made.  Falls back to the local
    heuristic (``_best_high_impact_action``) when recovery mode is
    active, when no credentials are configured, or when every LLM
    attempt fails.
    """
    if self.recovery_steps_remaining > 0:
        # Recovery mode: skip the LLM entirely for a few steps after a
        # streak of invalid/repaired actions was detected elsewhere.
        self.recovery_steps_remaining -= 1
        action, why = _best_high_impact_action(obs)
        return action, {
            "decision_source": "auto_recovery_policy",
            "provider": "heuristic",
            "model_used": "backlog_clearance_policy",
            "llm_attempts": 0,
            "llm_error": None,
            "llm_key_label": None,
            "repair_note": why,
        }
    attempts = 0
    last_error = ""
    allowed_actions, blocked_actions = _masked_action_type_hints(obs)
    # Machine-readable schema hints embedded in the user prompt so the
    # model emits the exact field names each action type requires.
    schema_hint = {
        "required_fields": {
            "set_priority_mode": ["action_type", "priority_mode"],
            "assign_capacity": ["action_type", "service", "officer_delta"],
            "request_missing_documents": ["action_type", "service"],
            "escalate_service": ["action_type", "service"],
            "advance_time": ["action_type"],
            "reallocate_officers": ["action_type", "service", "target_service", "officer_delta"],
        },
        "allowed_priority_mode": [m.value for m in PriorityMode],
        "allowed_services": [s.value for s in ServiceType],
    }
    system_prompt = (
        "You are controlling a government workflow simulator. "
        "Return exactly one JSON object only. No markdown. No explanation. "
        "Allowed action_type: set_priority_mode, assign_capacity, request_missing_documents, "
        "escalate_service, advance_time, reallocate_officers. "
        "Rules: "
        "1) reallocate_officers requires service + target_service + officer_delta>0 and source!=target. "
        "2) assign_capacity requires service + officer_delta>0. "
        "3) request_missing_documents requires service with missing_docs_cases>0. "
        "4) set_priority_mode requires priority_mode in [urgent_first, oldest_first, balanced, backlog_clearance]. "
        "5) Always prefer high-impact actions that reduce backlog/SLA risk over no-op loops. "
        "Use lowercase enum values."
    )
    user_prompt = (
        "Observation:\n"
        f"{obs.model_dump_json() if hasattr(obs, 'model_dump_json') else json.dumps(getattr(obs, 'dict', lambda: {})())}\n"
        f"Allowed action types now: {allowed_actions}\n"
        f"Blocked action types now: {blocked_actions}\n"
        f"Action schema hints: {json.dumps(schema_hint, separators=(',', ':'))}\n"
        f"Last action validity: {getattr(obs, 'last_action_valid', True)}\n"
        f"Last action message: {getattr(obs, 'last_action_message', '')}\n"
        "Return action JSON."
    )
    # Outer→inner: runtime (provider), api key, model.  The first model
    # that yields a parseable JSON action short-circuits the search.
    for runtime in self.llm_runtimes:
        provider = str(runtime["provider"])
        ranked_models = self._rank_runtime_models(provider, list(runtime["models"]))
        for client, key_label in runtime["clients"]:
            for model in ranked_models:
                attempts += 1
                stat_key = (provider, model)
                try:
                    out = client.chat.completions.create(
                        model=model,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt},
                        ],
                        temperature=0.0,
                        max_tokens=200,
                        stream=False,
                    )
                    content = (out.choices[0].message.content or "").strip()
                    action = _coerce_action(_extract_json_object(content))
                    # Count a successful call only for models we track.
                    if stat_key in self.llm_model_stats:
                        self.llm_model_stats[stat_key]["calls"] += 1
                    return action, {
                        "decision_source": "llm",
                        "provider": provider,
                        "model_used": model,
                        "llm_attempts": attempts,
                        "llm_error": None,
                        "llm_key_label": key_label,
                    }
                except Exception as exc:
                    # Both API errors and unparseable responses land here;
                    # repeated failures put the model in a short cooldown.
                    last_error = str(exc)
                    stat = self.llm_model_stats.get(stat_key)
                    if stat is not None:
                        stat["calls"] += 1
                        stat["failures"] += 1
                        if stat["failures"] >= 2:
                            stat["cooldown_until_step"] = self.step_idx + 5
                    continue
    # Every attempt failed (or there were no runtimes): heuristic fallback.
    action, why = _best_high_impact_action(obs)
    if not self.llm_runtimes:
        last_error = "No LLM credentials configured."
    return action, {
        "decision_source": "heuristic_fallback",
        "provider": "heuristic",
        "model_used": "backlog_clearance_policy",
        "llm_attempts": attempts,
        "llm_error": last_error or None,
        "llm_key_label": None,
        "repair_note": why,
    }
def _init_trained(self) -> None:
    """Load the trained RL model and prepare a masked gym env for replay.

    Raises:
        ValueError: if no model_path was provided for trained_rl mode.
    """
    import numpy as np
    from rl.gov_workflow_env import GovWorkflowGymEnv

    if not self.model_path:
        raise ValueError("model_path is required for trained_rl simulation.")
    resolved_path = _resolve_model_path_or_raise(self.model_path)
    self.rl_model = _load_model_cached_or_raise(resolved_path, self.model_type)
    self.rl_env = GovWorkflowGymEnv(
        task_id=self.task_id,
        hard_action_mask=True,
        seed=self.seed,
    )
    self.obs, _ = self.rl_env.reset(seed=self.seed)
    # Recurrent-policy bookkeeping: no hidden state yet, episode just began.
    self.rl_lstm_state = None
    self.rl_episode_start = np.array([True], dtype=bool)
def step_once(self) -> tuple[dict[str, Any], str, bool]:
    """Advance the simulation by one step.

    Returns:
        (row, log_line, finished) where *row* is the trace record for
        this step and *finished* is True once the episode has ended
        (either terminally or by hitting the step cap).

    Raises:
        RuntimeError: if called after the episode already finished.
    """
    if self.done:
        raise RuntimeError("Simulation already finished.")
    self.step_idx += 1
    if self.agent_mode == SimulationAgentMode.TRAINED_RL:
        row = self._step_trained()
    else:
        row = self._step_core()
    self.trace.append(row)
    self.total_reward += float(row["reward"])
    step_log = _log_step_line(row)
    finished = bool(row["done"]) or self.step_idx >= self.max_steps
    if not finished:
        return row, step_log, False
    self._finalize()
    # Force the step-cap case to read as terminal in the trace as well.
    row["done"] = True
    return row, step_log, True
def end_line(self) -> str:
    """Render the final one-line summary for the episode log."""
    if self.score is None:
        # Episode never reached grading: emit a fixed placeholder line.
        return "[END] success=false steps=0 score=0.00 rewards="
    reward_parts = [f"{float(step.get('reward', 0.0)):.2f}" for step in self.trace]
    if self.score >= 0.5:
        success = "true"
    else:
        success = "false"
    return (
        f"[END] success={success} steps={len(self.trace)} "
        f"score={self.score:.2f} rewards={','.join(reward_parts)}"
    )
def step_line(self, action: dict | ActionModel) -> dict[str, Any]:
    """Test wrapper for executing an action and returning observation + reward."""
    # Raw dict payloads from tests are normalised into an ActionModel first.
    coerced = _coerce_action(action) if isinstance(action, dict) else action
    step_result = self.env.step(coerced)
    self.obs, reward, _terminated, _truncated, _info = step_result
    return {"observation": self.obs, "reward": reward}
def snapshot(self) -> dict[str, Any]:
    """Return a JSON-friendly snapshot of the session's current state."""
    state: dict[str, Any] = {
        "task_id": self.task_id,
        "agent_mode": self.agent_mode.value,
        "seed": self.seed,
        "max_steps": self.max_steps,
        "step_idx": self.step_idx,
        "done": self.done,
    }
    state["total_reward"] = float(self.total_reward)
    state["score"] = self.score
    state["grader_name"] = self.grader_name
    state["summary"] = self.summary
    # Expose only the trace length (the full trace can be large) and a
    # defensive copy of the route description.
    state["trace_len"] = len(self.trace)
    state["llm_route"] = list(self.llm_route)
    return state
def close(self) -> None:
    """Best-effort teardown of both simulation environments.

    Errors raised while closing are deliberately swallowed so that
    cleanup never masks the result of a finished run.
    """
    for attr in ("env", "rl_env"):
        try:
            target = getattr(self, attr)
            if target is not None and hasattr(target, "close"):
                target.close()
        except Exception:
            pass
def _step_core(self) -> dict[str, Any]:
    """Execute one step for baseline-policy or LLM-inference modes.

    Asks the configured policy for an action, repairs or masks invalid
    LLM outputs, applies the action to the core env, and returns the
    trace row for this step (action, reward, queue stats, decision
    metadata).  In LLM mode it also updates per-model health stats and
    triggers the auto-recovery policy after repeated bad steps.

    Raises:
        RuntimeError: if the core env was never initialized.
    """

    def _info_field(info: Any, key: str) -> Any:
        # env.step() may report info as a plain dict (gym convention) or
        # as an object with attributes.  The previous getattr-only access
        # silently dropped flags like `invalid_action` on the dict path
        # (compare _step_trained, which uses info.get); support both.
        if isinstance(info, dict):
            return info.get(key)
        return getattr(info, key, None)

    if self.env is None:
        raise RuntimeError("Core simulation env not initialized.")
    if self.agent_mode == SimulationAgentMode.BASELINE_POLICY:
        action = self.policy(self.obs)
        meta = {
            "decision_source": "baseline_policy",
            "provider": "local_policy",
            "model_used": self.policy_name,
            "llm_attempts": 0,
            "llm_error": None,
            "llm_key_label": None,
        }
    else:
        # LLM policies may return (action, meta) or just an action.
        raw_decision = self.policy(self.obs)
        if isinstance(raw_decision, tuple) and len(raw_decision) == 2:
            action, meta = raw_decision
        else:
            action, meta = raw_decision, {}
        if not isinstance(meta, dict):
            meta = {}
        # Coerce non-ActionModel outputs; unusable outputs become a no-op.
        if not isinstance(action, ActionModel):
            if isinstance(action, dict):
                action = _coerce_action(action)
            else:
                action = ActionModel(action_type=ActionType.ADVANCE_TIME)
                meta["repair_note"] = "non-action output from llm policy, coerced to advance_time"
        # Enforce the runtime action mask: masked actions are replaced by
        # the best heuristic action instead of being rejected by the env.
        allowed_mask = _compute_action_mask(self.obs)
        if not bool(allowed_mask.get(action.action_type, True)):
            masked_fallback, why = _best_high_impact_action(self.obs)
            action = masked_fallback
            if meta.get("decision_source") == "llm":
                meta["decision_source"] = "llm_repaired"
            meta["repair_note"] = f"action masked at runtime; {why}"
        # Field-level repair (e.g. missing service / bad deltas).
        repaired_action, repair_note = _repair_action_for_observation(action, self.obs)
        if repair_note:
            action = repaired_action
            if meta.get("decision_source") == "llm":
                meta["decision_source"] = "llm_repaired"
            meta["repair_note"] = repair_note
    self.obs, reward, terminated, truncated, info = self.env.step(action)
    done = bool(terminated or truncated)
    last_action_error = _info_field(info, "last_action_error")
    if last_action_error is None:
        last_action_error = _info_field(info, "action_explanation")
    row = {
        "step": self.step_idx,
        "day": self.obs.day,
        "action_type": action.action_type.value,
        "action_payload": action.model_dump(exclude_none=True, mode="json"),
        "reward": float(reward),
        "done": done,
        "backlog": getattr(self.obs, "total_backlog", 0),
        "completed": getattr(self.obs, "total_completed", 0),
        "sla_breaches": getattr(self.obs, "total_sla_breaches", 0),
        "fairness_gap": float(
            getattr(self.obs, "fairness_gap", getattr(self.obs, "fairness_index", 0.0)) or 0.0
        ),
        "escalation_budget_remaining": getattr(self.obs, "escalation_budget_remaining", 0),
        "invalid_action": bool(_info_field(info, "invalid_action")),
        "last_action_error": last_action_error,
        "queue_rows": _queue_rows(self.obs),
    }
    row.update(meta)
    if self.agent_mode == SimulationAgentMode.LLM_INFERENCE:
        # Update per-model health stats from this step's outcome.
        is_repaired = row.get("decision_source") in {"llm_repaired", "auto_recovery_policy"}
        is_invalid = bool(row.get("invalid_action")) or bool(row.get("last_action_error"))
        model_used = str(row.get("model_used") or "")
        provider = str(row.get("provider") or "")
        stat_key = (provider, model_used)
        stat = self.llm_model_stats.get(stat_key)
        if stat is not None:
            if is_repaired:
                stat["repaired"] += 1
            if is_invalid:
                stat["invalid"] += 1
                stat["failures"] += 1
            else:
                # A clean step slowly forgives past failures.
                stat["failures"] = max(0, int(stat.get("failures", 0)) - 1)
        is_failure_pattern = is_invalid or is_repaired
        self.consecutive_failure_steps = self.consecutive_failure_steps + 1 if is_failure_pattern else 0
        if self.consecutive_failure_steps >= 4:
            # Four bad steps in a row: cool the model down and hand
            # control to the recovery policy for a few steps.
            if stat is not None:
                stat["cooldown_until_step"] = self.step_idx + 6
            self.recovery_steps_remaining = max(self.recovery_steps_remaining, 3)
            self.auto_switch_count += 1
            self.last_switch_reason = "repeated invalid/repaired pattern detected"
            row["switch_note"] = "auto-switched to recovery policy and deprioritized failing model"
            self.consecutive_failure_steps = 0
    return row
def _step_trained(self) -> dict[str, Any]:
    """Execute one step using the trained RL model.

    Handles both recurrent (LSTM) and maskable policies, keeps the
    chosen action inside the current action mask, and returns a trace
    row shaped like the one produced by ``_step_core``.
    """
    import numpy as np
    masks = self.rl_env.action_masks()
    if self.model_type == "recurrent":
        # Recurrent policies carry hidden state between predictions and
        # are not mask-aware at predict time, so the mask is enforced
        # after the fact.
        action, self.rl_lstm_state = self.rl_model.predict(
            self.obs,
            state=self.rl_lstm_state,
            episode_start=self.rl_episode_start,
            deterministic=True,
        )
        action_idx = int(action.item() if hasattr(action, "item") else action)
        if not (0 <= action_idx < masks.shape[0] and bool(masks[action_idx])):
            # Masked or out-of-range choice: substitute the first valid
            # action.  NOTE(review): the literal 18 is presumably the
            # advance_time index — confirm against rl.feature_builder's
            # action table.
            valid = np.flatnonzero(masks)
            action_idx = int(valid[0]) if valid.size > 0 else 18
    else:
        from sb3_contrib.common.maskable.utils import get_action_masks
        # Maskable policies receive the mask directly at predict time.
        action, _ = self.rl_model.predict(
            self.obs,
            action_masks=get_action_masks(self.rl_env),
            deterministic=True,
        )
        action_idx = int(action.item() if hasattr(action, "item") else action)
    self.obs, reward, terminated, truncated, info = self.rl_env.step(action_idx)
    done = bool(terminated or truncated)
    if self.model_type == "recurrent":
        # Flag episode boundaries so the next predict() resets the LSTM.
        self.rl_episode_start = np.array([done], dtype=bool)
    core_env = self.rl_env.core_env
    core_obs = core_env._build_observation()
    action_model, action_label = _decode_action_idx(action_idx)
    return {
        "step": self.step_idx,
        "day": core_obs.day,
        "action_type": action_label,
        "action_payload": action_model.model_dump(exclude_none=True, mode="json"),
        "action_index": action_idx,
        "reward": float(reward),
        "done": done,
        "backlog": core_obs.total_backlog,
        "completed": core_obs.total_completed,
        "sla_breaches": core_obs.total_sla_breaches,
        "fairness_gap": float(
            getattr(core_obs, "fairness_gap", getattr(core_obs, "fairness_index", 0.0)) or 0.0
        ),
        "escalation_budget_remaining": core_obs.escalation_budget_remaining,
        "invalid_action": bool(info.get("invalid_action", False)),
        "last_action_error": info.get("last_action_error") or info.get("action_explanation"),
        "queue_rows": _queue_rows(core_obs),
        "decision_source": "trained_rl",
        "provider": "rl",
        "model_used": self.model_path or "trained_rl",
        "llm_attempts": 0,
        "llm_error": None,
        "llm_key_label": None,
    }
def _finalize(self) -> None:
    """Grade the finished episode and build the summary dict.

    Idempotent: a second call returns immediately once ``done`` is set.
    """
    if self.done:
        return
    self.done = True
    from app.graders import grade_episode
    if self.agent_mode == SimulationAgentMode.TRAINED_RL:
        final_state = self.rl_env.core_env.state()
    else:
        final_state = self.env.state()
    gr = grade_episode(final_state)
    self.score = float(gr.score)
    self.grader_name = gr.grader_name
    # Tally steps by decision source for the summary.
    llm_steps = sum(1 for row in self.trace if row.get("decision_source") in {"llm", "llm_repaired"})
    fallback_steps = sum(
        1 for row in self.trace if row.get("decision_source") in {"heuristic_fallback", "auto_recovery_policy"}
    )
    repaired_steps = sum(
        1 for row in self.trace if row.get("decision_source") in {"llm_repaired", "auto_recovery_policy"}
    )
    total_steps = max(1, len(self.trace))  # guard against division by zero
    invalid_actions = _safe_invalid_action_count(final_state)
    invalid_rate = float(invalid_actions) / float(total_steps)
    repaired_rate = float(repaired_steps) / float(total_steps)
    # Rank the models that were actually called, best (lowest error
    # rates, then most calls) first.
    ranked_models: list[dict[str, Any]] = []
    if self.llm_model_stats:
        for (provider, model), stat in self.llm_model_stats.items():
            calls = int(stat.get("calls", 0))
            if calls <= 0:
                continue
            ranked_models.append(
                {
                    "provider": provider,
                    "model": model,
                    "calls": calls,
                    "invalid_rate": float(stat.get("invalid", 0)) / max(1, calls),
                    "repaired_rate": float(stat.get("repaired", 0)) / max(1, calls),
                }
            )
        ranked_models.sort(key=lambda x: (x["invalid_rate"], x["repaired_rate"], -x["calls"]))
    self.summary = {
        "total_steps": getattr(final_state, "total_steps", len(self.trace)),
        "total_completed": getattr(final_state, "total_completed", 0),
        "total_backlog": getattr(final_state, "total_backlog", 0),
        "total_sla_breaches": getattr(final_state, "total_sla_breaches", 0),
        "fairness_gap": float(getattr(final_state, "fairness_gap", 0.0) or 0.0),
        "total_invalid_actions": invalid_actions,
        "invalid_action_rate": invalid_rate,
        "llm_steps": llm_steps,
        "heuristic_fallback_steps": fallback_steps,
        "llm_repaired_steps": repaired_steps,
        "repaired_action_rate": repaired_rate,
        "auto_switch_count": self.auto_switch_count,
        "last_switch_reason": self.last_switch_reason,
        "effective_max_steps": self.max_steps,
        "recommended_min_steps": _recommended_min_steps(self.task_id),
    }
    # Mode-specific extras.
    if self.agent_mode == SimulationAgentMode.LLM_INFERENCE:
        self.summary["llm_route"] = list(self.llm_route)
        self.summary["llm_model_performance"] = ranked_models
    if self.agent_mode == SimulationAgentMode.TRAINED_RL:
        self.summary["model_path"] = self.model_path
        self.summary["model_type"] = self.model_type
def run_simulation(
    *,
    task_id: str,
    agent_mode: SimulationAgentMode,
    max_steps: int,
    seed: int | None,
    policy_name: str | None = None,
    model_path: str | None = None,
    model_type: Literal["maskable", "recurrent"] = "maskable",
) -> SimulationRun:
    """Run a full episode to completion and return its aggregated result."""
    sim = LiveSimulationSession(
        task_id=task_id,
        agent_mode=agent_mode,
        max_steps=max_steps,
        seed=seed,
        policy_name=policy_name,
        model_path=model_path,
        model_type=model_type,
    )
    try:
        # Drive the session until the episode terminates or hits max_steps.
        while not sim.done:
            sim.step_once()
        result = SimulationRun(
            task_id=sim.task_id,
            agent_mode=sim.agent_mode,
            seed=sim.seed,
            total_reward=float(sim.total_reward),
            score=float(sim.score or 0.0),
            grader_name=str(sim.grader_name or "unknown"),
            summary=dict(sim.summary or {}),
            trace=list(sim.trace),
        )
        return result
    finally:
        # Always release env resources, even if stepping raised.
        sim.close()
def _decode_action_idx(action_idx: int) -> tuple[ActionModel, str]:
    """Translate a discrete RL action index back into an ActionModel.

    Any lookup or validation failure degrades to a safe ADVANCE_TIME
    action labelled ``action_<idx>`` so tracing never breaks on unknown
    indices.
    """

    def _fallback() -> tuple[ActionModel, str]:
        return ActionModel(action_type=ActionType.ADVANCE_TIME), f"action_{action_idx}"

    try:
        from rl.feature_builder import ACTION_DECODE_TABLE
    except Exception:
        return _fallback()
    entry = ACTION_DECODE_TABLE.get(int(action_idx))
    if entry is None:
        return _fallback()
    action_type, service, priority_mode, delta = entry
    try:
        at = ActionType(str(action_type))
    except Exception:
        return _fallback()
    if at == ActionType.SET_PRIORITY_MODE:
        decoded = _action_model_from_kwargs(at, priority_mode=priority_mode)
    elif at == ActionType.ASSIGN_CAPACITY:
        decoded = _action_model_from_kwargs(at, service=service, officer_delta=delta or 1)
    elif at in (ActionType.REQUEST_MISSING_DOCUMENTS, ActionType.ESCALATE_SERVICE):
        decoded = _action_model_from_kwargs(at, service=service)
    elif at == ActionType.REALLOCATE_OFFICERS:
        src = _enum_service(service)
        if src is None:
            decoded = ActionModel(action_type=ActionType.ADVANCE_TIME)
        else:
            decoded = _action_model_from_kwargs(at, service=src, target_service=src, officer_delta=delta or 1)
    else:
        decoded = ActionModel(action_type=ActionType.ADVANCE_TIME)
    return decoded, at.value