Container-Port / server /environment.py
Draken1606's picture
fix: enforce strict exclusive score bounds across inference and env
f69544d
from __future__ import annotations
import random
import uuid
from dataclasses import dataclass
from typing import Any
from openenv.core.env_server import Environment, State
from openenv.core.env_server.types import EnvironmentMetadata
from models import ContainerAction, ContainerObservation
@dataclass(slots=True)
class Container:
id: str
priority: int
weight: float
# Per-difficulty yard layout and episode parameters.
#   n_stacks           - number of stacks in the yard
#   max_height         - maximum containers per stack
#   n_containers       - length of the arrival manifest
#   retrieval_interval - spacing between retrieval rounds (see step())
#   lookahead          - how many upcoming retrieval ids each observation exposes
#   priority_weights   - random.choices weights for priorities 1, 2, 3
DIFFICULTY_CONFIG = dict(
    easy=dict(
        n_stacks=6,
        max_height=4,
        n_containers=20,
        retrieval_interval=5,
        lookahead=5,
        priority_weights=[0.4, 0.4, 0.2],
    ),
    medium=dict(
        n_stacks=8,
        max_height=5,
        n_containers=35,
        retrieval_interval=5,
        lookahead=3,
        priority_weights=[0.33, 0.34, 0.33],
    ),
    hard=dict(
        n_stacks=10,
        max_height=6,
        n_containers=50,
        retrieval_interval=4,
        lookahead=0,
        priority_weights=[0.25, 0.35, 0.40],
    ),
)
class ContainerYardEnvironment(Environment):
    """Container-terminal yard environment served through openenv.

    An episode walks through a randomly generated manifest of containers.
    Each valid step places the next container on top of one stack. After
    every ``retrieval_interval`` placements, up to two containers are pulled
    from a priority-ordered retrieval queue. Every container sitting above a
    retrieved one counts as a rehandle and is penalised.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Reward for an out-of-range stack index or a placement onto a full stack.
    INVALID_ACTION_PENALTY = -2.0
    # Number of queue entries consumed by each retrieval round.
    RETRIEVALS_PER_ROUND = 2
    # Cost charged per container lifted off a retrieval target.
    REHANDLE_COST = 0.4

    def __init__(self) -> None:
        # Let the openenv base class initialise whatever attributes it owns.
        super().__init__()
        self._difficulty = "medium"
        self._state = State(episode_id=str(uuid.uuid4()), step_count=0)
        self._init_env("medium", seed=None)

    def _init_env(self, difficulty: str, seed: int | None) -> None:
        """(Re)build all per-episode state for ``difficulty``.

        ``difficulty`` is matched case-insensitively against
        ``DIFFICULTY_CONFIG``. Anything unrecognised (including non-strings)
        falls back to ``"medium"``.

        ``seed`` seeds a per-instance RNG. Because concurrent sessions are
        supported, reseeding the global ``random`` module would let one
        session's reset perturb the random stream of every other session.
        """
        difficulty = difficulty.strip().lower() if isinstance(difficulty, str) else ""
        if difficulty not in DIFFICULTY_CONFIG:
            difficulty = "medium"
        self._difficulty = difficulty
        cfg = DIFFICULTY_CONFIG[difficulty]
        self.n_stacks = cfg["n_stacks"]
        self.max_height = cfg["max_height"]
        self.n_containers = cfg["n_containers"]
        self.retrieval_interval = cfg["retrieval_interval"]
        self.lookahead = cfg["lookahead"]
        self.priority_weights = cfg["priority_weights"]
        # Random(None) seeds from OS entropy. An int seed reproduces exactly
        # the sequence that random.seed(seed) would have produced.
        self._rng = random.Random(seed)
        self.stacks: list[list[Container]] = [[] for _ in range(self.n_stacks)]
        self.rehandle_count = 0
        self.total_reward = 0.0
        self.done = False
        self.manifest: list[Container] = self._generate_manifest()
        self.retrieval_queue: list[str] = self._generate_retrieval_queue()
        self.retrieval_pointer = 0
        self.current_idx = 0

    def _generate_manifest(self) -> list[Container]:
        """Sample ``n_containers`` containers with ids C000, C001, ...

        Priority is drawn from {1, 2, 3} using ``priority_weights``. Weight is
        drawn uniformly from [5.0, 30.0] and rounded to one decimal place.
        """
        containers: list[Container] = []
        for i in range(self.n_containers):
            # Draw order (priority, then weight) matters for seeded replay.
            priority = self._rng.choices([1, 2, 3], weights=self.priority_weights)[0]
            weight = round(self._rng.uniform(5.0, 30.0), 1)
            containers.append(Container(id=f"C{i:03d}", priority=priority, weight=weight))
        return containers

    def _generate_retrieval_queue(self) -> list[str]:
        """Return every manifest id ordered by priority 1, 2, 3.

        Ids are shuffled within each priority group.
        """
        ids_by_priority: dict[int, list[str]] = {1: [], 2: [], 3: []}
        for c in self.manifest:
            ids_by_priority[c.priority].append(c.id)
        for p in ids_by_priority:
            self._rng.shuffle(ids_by_priority[p])
        return ids_by_priority[1] + ids_by_priority[2] + ids_by_priority[3]

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> ContainerObservation:
        """Start a new episode and return its initial observation.

        ``kwargs["difficulty"]`` selects the config and defaults to ``"medium"``.
        """
        difficulty = kwargs.get("difficulty", "medium")
        self._state = State(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
        )
        self._init_env(difficulty, seed)
        return self._observe(last_reward=0.0)

    def step(
        self,
        action: ContainerAction | int,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> ContainerObservation:
        """Place the current container on ``action.stack_index``.

        An out-of-range index or a full target stack costs
        ``INVALID_ACTION_PENALTY`` and places nothing. After a successful
        placement, a retrieval round runs once every ``retrieval_interval``
        placements. Once the episode is done, further steps are no-ops that
        return a zero-reward observation.
        """
        if self.done:
            return self._observe(0.0)
        if isinstance(action, int):
            action = ContainerAction(stack_index=action)
        stack_index = action.stack_index
        # Short-circuit keeps the stack lookup from indexing out of range.
        if (
            stack_index < 0
            or stack_index >= self.n_stacks
            or len(self.stacks[stack_index]) >= self.max_height
        ):
            return self._reject_action()

        current = self.manifest[self.current_idx]
        self.stacks[stack_index].append(current)
        placement_reward = self._placement_reward(stack_index, current)
        self.current_idx += 1
        self._state.step_count += 1

        retrieval_cost = 0.0
        # Schedule on placements, not raw steps. Invalid actions also bump
        # step_count, so keying on it let a rejected move that landed on a
        # multiple of the interval silently cancel a retrieval round.
        if self.current_idx % self.retrieval_interval == 0:
            retrieval_cost, _ = self._trigger_retrieval()

        reward = placement_reward - retrieval_cost
        self.total_reward += reward
        self.done = self.current_idx >= len(self.manifest)
        return self._observe(reward)

    def _reject_action(self) -> ContainerObservation:
        """Charge the invalid-action penalty and advance the step counter only."""
        reward = self.INVALID_ACTION_PENALTY
        self.total_reward += reward
        self._state.step_count += 1
        return self._observe(reward)

    def _placement_reward(self, stack_index: int, container: Container) -> float:
        """Shaping reward for the container just appended to ``stack_index``."""
        # Zero-based index of the just-placed container, counted from the bottom.
        stack_depth = len(self.stacks[stack_index]) - 1
        accessibility = (self.max_height - stack_depth) / self.max_height
        # Priority 1 -> 1.0, 2 -> 0.67, 3 -> 0.33.
        priority_weight = (4 - container.priority) / 3.0
        base = 0.3 * accessibility * priority_weight
        # Bonus: a priority-1 container placed as the first or second
        # container of its stack (depth 0 or 1).
        if container.priority == 1 and stack_depth <= 1:
            base += 0.15
        # Penalty: a lower-priority container now covers a higher-priority
        # one, which will force a rehandle when the lower one is retrieved.
        if stack_depth > 0:
            below = self.stacks[stack_index][-2]
            if container.priority > below.priority:
                base -= 0.2 * (container.priority - below.priority) / 2.0
        return round(base, 4)

    def _trigger_retrieval(self) -> tuple[float, list[str]]:
        """Dequeue up to ``RETRIEVALS_PER_ROUND`` ids and retrieve each one.

        Returns the summed rehandle cost and the ids that were dequeued.
        """
        total_cost = 0.0
        done_ids: list[str] = []
        for _ in range(self.RETRIEVALS_PER_ROUND):
            if self.retrieval_pointer >= len(self.retrieval_queue):
                break
            target_id = self.retrieval_queue[self.retrieval_pointer]
            self.retrieval_pointer += 1
            total_cost += self._retrieve(target_id)
            done_ids.append(target_id)
        return total_cost, done_ids

    def _retrieve(self, target_id: str) -> float:
        """Remove ``target_id`` from the yard and return its rehandle cost.

        Every container above the target counts as one rehandle. Those
        containers stay in the stack in their original order. If the target
        has not been placed yet, nothing happens and the cost is 0.0. The
        queue entry is still consumed by the caller, so it is not retried.
        """
        for stack in self.stacks:
            for i, c in enumerate(stack):
                if c.id == target_id:
                    rehandles = len(stack) - 1 - i  # containers above the target
                    self.rehandle_count += rehandles
                    stack.pop(i)
                    return round(rehandles * self.REHANDLE_COST, 4)
        return 0.0  # container not yet in yard - no penalty

    def _get_upcoming_retrievals(self) -> list[str]:
        """Return the next ``lookahead`` queued ids; ``[]`` when lookahead is 0."""
        start = self.retrieval_pointer
        end = min(start + self.lookahead, len(self.retrieval_queue))
        return self.retrieval_queue[start:end]

    @property
    def state(self) -> State:
        """Episode id and step counter for the current episode."""
        return self._state

    def _observe(self, last_reward: float = 0.0) -> ContainerObservation:
        """Build an observation snapshot of the yard and the next container."""
        stack_states = [
            [{"id": c.id, "priority": c.priority} for c in s] for s in self.stacks
        ]
        current = None
        if self.current_idx < len(self.manifest):
            c = self.manifest[self.current_idx]
            current = {"id": c.id, "priority": c.priority, "weight": c.weight}
        return ContainerObservation(
            stack_states=stack_states,
            current_container=current,
            upcoming_retrievals=self._get_upcoming_retrievals(),
            rehandle_count=self.rehandle_count,
            step=self._state.step_count,
            containers_remaining=len(self.manifest) - self.current_idx,
            n_stacks=self.n_stacks,
            max_height=self.max_height,
            difficulty=self._difficulty,
            last_reward=last_reward,
            score=self.score(),
            done=self.done,
        )

    def score(self) -> float:
        """Normalized score strictly in (0.0, 1.0).

        Computed as ``1 - rehandles / worst_case``, where ``worst_case``
        assumes ``max_height - 1`` rehandles for every queue entry dequeued
        so far. Dequeued ids whose container was not yet in the yard cost
        nothing but still count toward ``worst_case``. Before any retrieval
        round the score is a neutral 0.5.
        """
        n_retrieved = self.retrieval_pointer
        worst_case = n_retrieved * (self.max_height - 1)
        if worst_case == 0:
            return 0.5  # no retrievals yet - neutral score
        raw = 1.0 - self.rehandle_count / worst_case
        # Clamp strictly inside (0, 1) - grader requires score != 0.0 and score != 1.0
        score = max(0.01, min(raw, 0.99))
        return round(score, 4)

    def get_state(self) -> dict[str, Any]:
        """Return the current observation as a plain dict."""
        return self._observe().model_dump()

    def get_metadata(self) -> EnvironmentMetadata:
        """Static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="container-port-env",
            description=(
                "Container terminal yard environment where agents place incoming "
                "containers into stacks to minimize rehandle cost during retrieval."
            ),
            version="0.1.0",
        )
# Shorter alias: both names refer to the same class object.
ContainerYardEnv = ContainerYardEnvironment