Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from typing import Any, Dict, List, Literal, Optional | |
| from openenv_core.env_server.types import ( | |
| Action as BaseAction, | |
| Observation as BaseObservation, | |
| State as BaseState, | |
| ) | |
| # ============================================================================= | |
| # Action — extends OpenEnv Action | |
| # ============================================================================= | |
class ActionType(str, Enum):
    """The four possible routing actions.

    The ``str`` mixin makes every member compare equal to its wire value
    (e.g. ``ActionType.SHED_LOAD == "shed_load"``), so members can be used
    wherever the plain string form is expected.
    """

    ROUTE_TO_A = "route_to_a"
    ROUTE_TO_B = "route_to_b"
    ROUTE_TO_C = "route_to_c"
    SHED_LOAD = "shed_load"
# kw_only=True lets the required `action_type` field follow the defaulted
# field(s) inherited from the base class without a dataclass ordering error.
# NOTE(review): assumes BaseAction is itself a dataclass — confirm against
# openenv_core.env_server.types.
@dataclass(kw_only=True)
class Action(BaseAction):
    """
    Agent action: route a request to a provider or shed load.

    Extends OpenEnv Action (which provides `metadata` field).

    Attributes:
        action_type: One of the four routing choices. It may be supplied as
            a plain string; after construction it is always stored as an
            ``ActionType`` member (which still compares equal to the string).

    Raises:
        ValueError: If ``action_type`` is a string that is not a valid
            ``ActionType`` value.
    """

    action_type: Literal["route_to_a", "route_to_b", "route_to_c", "shed_load"]

    def __post_init__(self) -> None:
        """Normalize ``action_type`` to an ``ActionType`` member."""
        # ActionType subclasses str, so existing members also pass this check;
        # ActionType(member) simply returns the same member.
        if isinstance(self.action_type, str):
            self.action_type = ActionType(self.action_type)
| # ============================================================================= | |
| # Observation — extends OpenEnv Observation | |
| # ============================================================================= | |
# kw_only=True lets these required fields follow the defaulted fields
# inherited from the base class without a dataclass ordering error.
# NOTE(review): assumes BaseObservation is itself a dataclass — confirm
# against openenv_core.env_server.types.
@dataclass(kw_only=True)
class Observation(BaseObservation):
    """
    Agent-visible observation. ALL numeric fields are normalized to [0.0, 1.0].

    Extends OpenEnv Observation (which provides `done`, `reward`, `metadata` fields).

    Values outside [0.0, 1.0] are clamped into that range at construction
    time rather than rejected.
    """

    # Provider health (recent success rates)
    provider_a_status: float
    provider_b_status: float
    provider_c_status: float

    # Resource state
    budget_remaining: float
    queue_backlog: float
    system_latency: float

    # Episode progress
    step_count: float

    def __post_init__(self) -> None:
        """Clamp every numeric observation field into [0.0, 1.0]."""
        for field_name in (
            "provider_a_status",
            "provider_b_status",
            "provider_c_status",
            "budget_remaining",
            "queue_backlog",
            "system_latency",
            "step_count",
        ):
            raw_value = getattr(self, field_name)
            setattr(self, field_name, max(0.0, min(1.0, raw_value)))
| # ============================================================================= | |
| # Internal State (raw units, for debugging / trace only) | |
| # ============================================================================= | |
@dataclass
class ProviderState:
    """Internal state of a single provider in raw units."""

    name: str
    base_reliability: float  # initial reliability [0, 1]
    current_health: float  # current health [0, 1]
    cost_per_request: float  # dollars
    base_latency_ms: float  # base latency in ms
    total_requests: int = 0
    successful_requests: int = 0

    def observed_success_rate(self) -> float:
        """Return the success rate over all requests recorded so far.

        This is the cumulative ratio ``successful_requests / total_requests``
        (it is not limited to a recent window). Before any request has been
        recorded it falls back to ``base_reliability``, which also avoids a
        division by zero.
        """
        if self.total_requests == 0:
            return self.base_reliability
        return self.successful_requests / self.total_requests
@dataclass
class InternalState:
    """
    Full internal state in raw units. NOT exposed to the agent.

    Used for manual trace, debugging, and the oracle policy.
    Every field has a default, so ``InternalState()`` yields an empty state
    with its own fresh containers.
    """

    # Forward reference keeps this class independent of definition order.
    providers: Dict[str, "ProviderState"] = field(default_factory=dict)
    budget_dollars: float = 0.0
    initial_budget_dollars: float = 0.0
    queue_backlog_count: int = 0
    max_queue_backlog: int = 10
    last_latency_ms: float = 0.0
    sla_ceiling_ms: float = 500.0
    current_step: int = 0
    max_steps: int = 20
    episode_done: bool = False
    history: List[Dict[str, Any]] = field(default_factory=list)

    # Windowed success tracking (last N requests per provider)
    provider_window: Dict[str, List[bool]] = field(default_factory=dict)
    window_size: int = 5

    # Probed providers: tracks which providers have been routed to at least once
    probed_providers: set = field(default_factory=set)

    # Resolved (jittered) degradation onsets for this episode
    actual_degradation_start: int = 0
    actual_secondary_degradation_start: int = 999

    def get_windowed_success_rate(self, provider_name: str) -> float:
        """Get success rate over the last `window_size` requests for a provider.

        Args:
            provider_name: Key into ``provider_window`` / ``providers``.

        Returns:
            Fraction of successes among the most recent ``window_size``
            recorded outcomes, or the provider's ``base_reliability`` when
            no outcome has been recorded yet.

        Raises:
            KeyError: If no outcomes are recorded and ``provider_name`` is
                not present in ``providers``.
        """
        window = self.provider_window.get(provider_name, [])
        if not window:
            return self.providers[provider_name].base_reliability
        # Enforce the window here as well, so an untrimmed history list still
        # yields a "last N" rate. A non-positive window_size means "no limit"
        # (and avoids the window[-0:] == whole-list slicing surprise).
        recent = window[-self.window_size:] if self.window_size > 0 else window
        return sum(recent) / len(recent)
| # ============================================================================= | |
| # Task Configuration | |
| # ============================================================================= | |
@dataclass
class TaskConfig:
    """
    Configuration for a task scenario. Passed to reset(scenario=config).

    NOT a subclass — just a data container. Only ``name`` and
    ``description`` are required; every other setting has a default.
    """

    name: str
    description: str

    # Budget
    initial_budget: float = 5.0  # dollars

    # Provider costs (per request, dollars)
    cost_a: float = 0.01
    cost_b: float = 0.05
    cost_c: float = 0.10

    # Provider base reliability
    reliability_a: float = 0.70
    reliability_b: float = 0.90
    reliability_c: float = 0.99

    # Provider base latency (ms)
    latency_a: float = 100.0
    latency_b: float = 150.0
    latency_c: float = 200.0

    # SLA
    sla_ceiling_ms: float = 500.0

    # Degradation config (primary)
    degradation_start_step: int = 0  # step at which degradation begins
    degradation_rate: float = 0.0  # health reduction per step for the degradation_target provider
    degradation_target: str = "A"  # which provider degrades
    degradation_start_jitter: int = 0  # ±jitter applied per episode to degradation_start_step

    # Secondary degradation (for multi-provider scenarios)
    secondary_degradation_start_step: int = 999  # 999 = no secondary degradation
    secondary_degradation_rate: float = 0.0
    secondary_degradation_target: str = ""  # empty = no secondary degradation
    secondary_degradation_start_jitter: int = 0  # ±jitter applied per episode to secondary_degradation_start_step

    # Episode
    max_steps: int = 20
    max_queue_backlog: int = 10

    # Stochastic noise
    latency_noise_std: float = 30.0  # ms std dev added to base latency
| # ============================================================================= | |
| # OpenEnv State — extends BaseState | |
| # ============================================================================= | |
# NOTE(review): assumes BaseState is itself a dataclass whose fields all have
# defaults — confirm against openenv_core.env_server.types.
@dataclass
class EnvState(BaseState):
    """
    OpenEnv-compatible state object returned by the `state` property.

    Extends BaseState (which provides `episode_id`, `step_count` fields).

    Attributes:
        scenario_name: Name of the active scenario; empty string by default.
        is_done: Episode-finished flag; ``False`` by default.
    """

    scenario_name: str = ""
    is_done: bool = False