Spaces:
Sleeping
Sleeping
Shabista Sehar commited on
Commit ·
11f9523
0
Parent(s):
Initial: Container Port OpenEnv
Browse files- Dockerfile +14 -0
- README.md +71 -0
- client/__init__.py +1 -0
- client/container_env.py +37 -0
- inference.py +224 -0
- openenv.yaml +30 -0
- pyproject.toml +17 -0
- requirements.txt +8 -0
- server/__init__.py +1 -0
- server/__pycache__/__init__.cpython-313.pyc.1828594663216 +0 -0
- server/__pycache__/__init__.cpython-313.pyc.2347090562224 +0 -0
- server/__pycache__/environment.cpython-313.pyc.1828594650032 +0 -0
- server/__pycache__/environment.cpython-313.pyc.2347106154928 +0 -0
- server/__pycache__/models.cpython-313.pyc.2347106454640 +0 -0
- server/__pycache__/server.cpython-313.pyc.2347090514912 +0 -0
- server/environment.py +211 -0
- server/models.py +36 -0
- server/server.py +84 -0
- tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.10200 +0 -0
- tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.16772 +0 -0
- tests/test_env.py +110 -0
Dockerfile
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY requirements.txt .
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
+
|
| 8 |
+
COPY . .
|
| 9 |
+
|
| 10 |
+
ENV PYTHONPATH=/app
|
| 11 |
+
|
| 12 |
+
EXPOSE 7860
|
| 13 |
+
|
| 14 |
+
CMD ["uvicorn", "server.server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Container Port Environment
|
| 2 |
+
|
| 3 |
+
An OpenEnv-compatible RL environment for container yard management at a shipping terminal.
|
| 4 |
+
|
| 5 |
+
## Task
|
| 6 |
+
|
| 7 |
+
A ship arrives with N containers (priority 1=urgent, 2=normal, 3=low). The agent places each into
|
| 8 |
+
stacks. At regular intervals, specific containers are retrieved. If a target is buried under others,
|
| 9 |
+
each container above it is a **rehandle** — expensive in real port operations.
|
| 10 |
+
|
| 11 |
+
**Goal: minimize total rehandle operations across the episode.**
|
| 12 |
+
|
| 13 |
+
## Difficulty Levels
|
| 14 |
+
|
| 15 |
+
| Parameter | Easy | Medium | Hard |
|
| 16 |
+
|--------------------|----------|----------|----------|
|
| 17 |
+
| Stacks | 6 | 8 | 10 |
|
| 18 |
+
| Max stack height | 4 | 5 | 6 |
|
| 19 |
+
| Containers | 20 | 35 | 50 |
|
| 20 |
+
| Retrieval interval | every 5 | every 5 | every 4 |
|
| 21 |
+
| Lookahead shown | 5 | 3 | 0 |
|
| 22 |
+
|
| 23 |
+
## Reward
|
| 24 |
+
|
| 25 |
+
| Event | Reward |
|
| 26 |
+
|---|---|
|
| 27 |
+
| Accessible placement of priority-1 (near top) | up to +0.45 |
|
| 28 |
+
| General placement | +0.03 to +0.30 |
|
| 29 |
+
| Burying high-priority under low-priority | -0.10 to -0.20 |
|
| 30 |
+
| Invalid action (full stack / bad index) | -2.0 |
|
| 31 |
+
| Each rehandle at retrieval time | -0.40 |
|
| 32 |
+
|
| 33 |
+
## Score
|
| 34 |
+
|
| 35 |
+
`score = 1.0 - (actual_rehandles / worst_case_rehandles)`, in [0.0, 1.0].
|
| 36 |
+
|
| 37 |
+
## Setup
|
| 38 |
+
```bash
|
| 39 |
+
pip install -r requirements.txt
|
| 40 |
+
uvicorn server.server:app --host 0.0.0.0 --port 7860
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Run inference
|
| 44 |
+
```bash
|
| 45 |
+
# Greedy agent, all difficulties
|
| 46 |
+
python inference.py --difficulty all
|
| 47 |
+
|
| 48 |
+
# LLM agent (requires HF token in env)
|
| 49 |
+
export HF_TOKEN=hf_your_token_here
|
| 50 |
+
python inference.py --use-llm --difficulty all
|
| 51 |
+
|
| 52 |
+
# Against deployed HF Space
|
| 53 |
+
python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space --difficulty all
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Docker
|
| 57 |
+
```bash
|
| 58 |
+
docker build -t container-port-env .
|
| 59 |
+
docker run -p 7860:7860 container-port-env
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## API
|
| 63 |
+
|
| 64 |
+
- `GET /ping` — health check
|
| 65 |
+
- `GET /health` — server stats
|
| 66 |
+
- `WS /ws` — WebSocket interface
|
| 67 |
+
|
| 68 |
+
WebSocket messages:
|
| 69 |
+
- `{"type": "reset", "difficulty": "easy"}` — start episode
|
| 70 |
+
- `{"type": "step", "action": {"stack_index": 2}}` — place container
|
| 71 |
+
- `{"type": "state"}` — get full state with score
|
client/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
client/container_env.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import websockets
|
| 3 |
+
from typing import Any, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
class ContainerEnvClient:
|
| 6 |
+
"""Async client for Container Port OpenEnv."""
|
| 7 |
+
|
| 8 |
+
def __init__(self, base_url: str = "http://localhost:7860"):
|
| 9 |
+
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
|
| 10 |
+
self.ws_url = ws_url.rstrip("/") + "/ws"
|
| 11 |
+
self._ws = None
|
| 12 |
+
|
| 13 |
+
async def __aenter__(self):
|
| 14 |
+
self._ws = await websockets.connect(self.ws_url)
|
| 15 |
+
return self
|
| 16 |
+
|
| 17 |
+
async def __aexit__(self, *args):
|
| 18 |
+
if self._ws:
|
| 19 |
+
await self._ws.close()
|
| 20 |
+
|
| 21 |
+
async def reset(self, difficulty: str = "medium") -> Dict[str, Any]:
|
| 22 |
+
await self._ws.send(json.dumps({"type": "reset", "difficulty": difficulty}))
|
| 23 |
+
resp = json.loads(await self._ws.recv())
|
| 24 |
+
return resp["observation"]
|
| 25 |
+
|
| 26 |
+
async def step(self, stack_index: int) -> Tuple[Dict, float, bool, Dict]:
|
| 27 |
+
await self._ws.send(json.dumps({
|
| 28 |
+
"type": "step",
|
| 29 |
+
"action": {"stack_index": stack_index}
|
| 30 |
+
}))
|
| 31 |
+
resp = json.loads(await self._ws.recv())
|
| 32 |
+
return resp["observation"], resp["reward"], resp["done"], resp.get("info", {})
|
| 33 |
+
|
| 34 |
+
async def state(self) -> Dict[str, Any]:
|
| 35 |
+
await self._ws.send(json.dumps({"type": "state"}))
|
| 36 |
+
resp = json.loads(await self._ws.recv())
|
| 37 |
+
return resp["state"]
|
inference.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Container Port OpenEnv — Baseline Inference Script
|
| 4 |
+
SST x Meta PyTorch OpenEnv Hackathon
|
| 5 |
+
|
| 6 |
+
Required environment variables (or set below):
|
| 7 |
+
HF_TOKEN - Your Hugging Face token
|
| 8 |
+
API_BASE_URL - LLM API endpoint (default: https://router.huggingface.co/v1)
|
| 9 |
+
MODEL_NAME - Model identifier (default: meta-llama/Llama-3.1-8B-Instruct)
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python inference.py
|
| 13 |
+
python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space --difficulty all
|
| 14 |
+
python inference.py --difficulty easy
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import json
|
| 20 |
+
import asyncio
|
| 21 |
+
import argparse
|
| 22 |
+
import websockets
|
| 23 |
+
from openai import OpenAI
|
| 24 |
+
|
| 25 |
+
# Required configuration variables
|
| 26 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 27 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
|
| 28 |
+
HF_TOKEN = os.getenv("HF_TOKEN", "") # set your HF token here or via env var
|
| 29 |
+
#
|
| 30 |
+
|
| 31 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _llm_client() -> OpenAI:
|
| 35 |
+
"""Return an OpenAI-compatible client pointed at HF Inference Router."""
|
| 36 |
+
return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def greedy_decide(obs: dict) -> int:
|
| 40 |
+
"""
|
| 41 |
+
Greedy heuristic agent — no LLM call.
|
| 42 |
+
Scores each valid stack by accessibility and priority compatibility.
|
| 43 |
+
"""
|
| 44 |
+
stacks = obs["stack_states"]
|
| 45 |
+
current = obs.get("current_container")
|
| 46 |
+
max_height = obs["max_height"]
|
| 47 |
+
upcoming = set(obs.get("upcoming_retrievals", []))
|
| 48 |
+
|
| 49 |
+
if current is None:
|
| 50 |
+
return 0
|
| 51 |
+
|
| 52 |
+
cur_priority = current["priority"]
|
| 53 |
+
best_stack, best_score = -1, float("-inf")
|
| 54 |
+
|
| 55 |
+
for i, stack in enumerate(stacks):
|
| 56 |
+
depth = len(stack)
|
| 57 |
+
if depth >= max_height:
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
score = 0.0
|
| 61 |
+
accessibility = (max_height - depth) / max_height
|
| 62 |
+
score += accessibility * (4 - cur_priority)
|
| 63 |
+
|
| 64 |
+
if depth > 0:
|
| 65 |
+
top_priority = stack[-1]["priority"]
|
| 66 |
+
if cur_priority > top_priority:
|
| 67 |
+
score -= 10.0 * (cur_priority - top_priority)
|
| 68 |
+
elif cur_priority < top_priority:
|
| 69 |
+
score += 3.0
|
| 70 |
+
|
| 71 |
+
if current["id"] in upcoming:
|
| 72 |
+
score += 5.0 * accessibility
|
| 73 |
+
|
| 74 |
+
if depth > 0:
|
| 75 |
+
score += 0.5
|
| 76 |
+
|
| 77 |
+
if score > best_score:
|
| 78 |
+
best_score = score
|
| 79 |
+
best_stack = i
|
| 80 |
+
|
| 81 |
+
if best_stack == -1:
|
| 82 |
+
for i, stack in enumerate(stacks):
|
| 83 |
+
if len(stack) < max_height:
|
| 84 |
+
return i
|
| 85 |
+
return best_stack
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def llm_decide(obs: dict) -> int:
|
| 89 |
+
"""Use HF-hosted LLM via OpenAI-compatible client to choose a stack."""
|
| 90 |
+
stacks = obs["stack_states"]
|
| 91 |
+
current = obs.get("current_container")
|
| 92 |
+
n_stacks = obs["n_stacks"]
|
| 93 |
+
max_height = obs["max_height"]
|
| 94 |
+
upcoming = obs.get("upcoming_retrievals", [])
|
| 95 |
+
difficulty = obs.get("difficulty", "medium")
|
| 96 |
+
|
| 97 |
+
stack_lines = []
|
| 98 |
+
for i, stack in enumerate(stacks):
|
| 99 |
+
if not stack:
|
| 100 |
+
stack_lines.append(f" Stack {i}: EMPTY (0/{max_height})")
|
| 101 |
+
else:
|
| 102 |
+
contents = ", ".join(f"{c['id']}(p{c['priority']})" for c in stack)
|
| 103 |
+
stack_lines.append(
|
| 104 |
+
f" Stack {i}: [{contents}] depth={len(stack)}/{max_height},"
|
| 105 |
+
f" top=priority-{stack[-1]['priority']}"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
prompt = (
|
| 109 |
+
f"You are an expert container yard planner.\n"
|
| 110 |
+
f"TASK: Place the incoming container into a stack to MINIMIZE future rehandle operations.\n"
|
| 111 |
+
f"RULE: When a container is retrieved, every container ON TOP of it must be moved (rehandle).\n"
|
| 112 |
+
f"Priority 1=URGENT (retrieved first), 2=Normal, 3=Low (retrieved last).\n\n"
|
| 113 |
+
f"DIFFICULTY: {difficulty}\n"
|
| 114 |
+
f"UPCOMING RETRIEVALS (next to be retrieved, in order): "
|
| 115 |
+
f"{upcoming if upcoming else 'Unknown (hard mode)'}\n\n"
|
| 116 |
+
f"CONTAINER TO PLACE: id={current['id']}, priority={current['priority']}, "
|
| 117 |
+
f"weight={current['weight']}kg\n\n"
|
| 118 |
+
f"STACK STATES (bottomtop):\n" + "\n".join(stack_lines) + "\n\n"
|
| 119 |
+
f"Respond with ONLY valid JSON: {{\"stack_index\": <integer 0-{n_stacks-1}>}}"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
client = _llm_client()
|
| 124 |
+
response = client.chat.completions.create(
|
| 125 |
+
model=MODEL_NAME,
|
| 126 |
+
max_tokens=64,
|
| 127 |
+
temperature=0.0,
|
| 128 |
+
messages=[{"role": "user", "content": prompt}],
|
| 129 |
+
)
|
| 130 |
+
text = response.choices[0].message.content.strip()
|
| 131 |
+
# strip markdown fences if model wraps in ```json ... ```
|
| 132 |
+
if "```" in text:
|
| 133 |
+
text = text.split("```")[1]
|
| 134 |
+
if text.startswith("json"):
|
| 135 |
+
text = text[4:]
|
| 136 |
+
decision = json.loads(text.strip())
|
| 137 |
+
idx = int(decision["stack_index"])
|
| 138 |
+
if 0 <= idx < n_stacks and len(obs["stack_states"][idx]) < max_height:
|
| 139 |
+
return idx
|
| 140 |
+
except Exception as e:
|
| 141 |
+
print(f" [LLM fallback: {e}]", file=sys.stderr)
|
| 142 |
+
|
| 143 |
+
return greedy_decide(obs)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
async def run_episode(url: str, difficulty: str = "medium", use_llm: bool = False) -> float:
|
| 147 |
+
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
| 148 |
+
if not ws_url.endswith("/ws"):
|
| 149 |
+
ws_url = ws_url.rstrip("/") + "/ws"
|
| 150 |
+
|
| 151 |
+
# [START] log
|
| 152 |
+
print(json.dumps({"type": "[START]", "task": difficulty, "difficulty": difficulty,
|
| 153 |
+
"env_url": url, "model": MODEL_NAME if use_llm else "greedy"}))
|
| 154 |
+
sys.stdout.flush()
|
| 155 |
+
|
| 156 |
+
total_reward = 0.0
|
| 157 |
+
step = 0
|
| 158 |
+
|
| 159 |
+
async with websockets.connect(ws_url) as ws:
|
| 160 |
+
await ws.send(json.dumps({"type": "reset", "difficulty": difficulty}))
|
| 161 |
+
resp = json.loads(await ws.recv())
|
| 162 |
+
obs = resp["observation"]
|
| 163 |
+
|
| 164 |
+
while not obs.get("done", False):
|
| 165 |
+
action_idx = llm_decide(obs) if use_llm else greedy_decide(obs)
|
| 166 |
+
|
| 167 |
+
await ws.send(json.dumps({"type": "step", "action": {"stack_index": action_idx}}))
|
| 168 |
+
resp = json.loads(await ws.recv())
|
| 169 |
+
obs = resp["observation"]
|
| 170 |
+
reward = resp["reward"]
|
| 171 |
+
done = resp["done"]
|
| 172 |
+
total_reward += reward
|
| 173 |
+
step += 1
|
| 174 |
+
|
| 175 |
+
# [STEP] log
|
| 176 |
+
print(json.dumps({
|
| 177 |
+
"type": "[STEP]",
|
| 178 |
+
"step": step,
|
| 179 |
+
"action": action_idx,
|
| 180 |
+
"reward": round(reward, 4),
|
| 181 |
+
"total_reward": round(total_reward, 4),
|
| 182 |
+
"done": done,
|
| 183 |
+
"rehandle_count": obs["rehandle_count"],
|
| 184 |
+
}))
|
| 185 |
+
sys.stdout.flush()
|
| 186 |
+
|
| 187 |
+
# fetch final state for score
|
| 188 |
+
await ws.send(json.dumps({"type": "state"}))
|
| 189 |
+
state_resp = json.loads(await ws.recv())
|
| 190 |
+
state = state_resp["state"]
|
| 191 |
+
|
| 192 |
+
final_score = state.get("score", 0.0)
|
| 193 |
+
|
| 194 |
+
# [END] log
|
| 195 |
+
print(json.dumps({
|
| 196 |
+
"type": "[END]",
|
| 197 |
+
"task": difficulty,
|
| 198 |
+
"difficulty": difficulty,
|
| 199 |
+
"total_reward": round(total_reward, 4),
|
| 200 |
+
"final_score": final_score,
|
| 201 |
+
"total_steps": step,
|
| 202 |
+
"rehandle_count": state.get("rehandle_count", 0),
|
| 203 |
+
}))
|
| 204 |
+
sys.stdout.flush()
|
| 205 |
+
|
| 206 |
+
return final_score
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
async def run_all(url: str, use_llm: bool = False):
|
| 210 |
+
for diff in ["easy", "medium", "hard"]:
|
| 211 |
+
await run_episode(url, difficulty=diff, use_llm=use_llm)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
if __name__ == "__main__":
|
| 215 |
+
parser = argparse.ArgumentParser(description="Container Port Baseline Agent")
|
| 216 |
+
parser.add_argument("--url", default=ENV_URL)
|
| 217 |
+
parser.add_argument("--difficulty", default="all", choices=["easy", "medium", "hard", "all"])
|
| 218 |
+
parser.add_argument("--use-llm", action="store_true")
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
if args.difficulty == "all":
|
| 222 |
+
asyncio.run(run_all(args.url, use_llm=args.use_llm))
|
| 223 |
+
else:
|
| 224 |
+
asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=args.use_llm))
|
openenv.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: container-port-env
|
| 2 |
+
version: "0.1.0"
|
| 3 |
+
description: >
|
| 4 |
+
Container terminal yard RL environment. An agent places incoming ship
|
| 5 |
+
containers into stacks of limited height to minimize costly rehandle
|
| 6 |
+
operations during retrieval. Features 3 difficulty levels (easy/medium/hard)
|
| 7 |
+
with different stack configurations, retrieval frequencies, and lookahead
|
| 8 |
+
visibility. Models real port logistics decision-making.
|
| 9 |
+
tags:
|
| 10 |
+
- logistics
|
| 11 |
+
- planning
|
| 12 |
+
- real-world
|
| 13 |
+
- combinatorial-optimization
|
| 14 |
+
sdk: docker
|
| 15 |
+
entry_point: server.server:app
|
| 16 |
+
tools:
|
| 17 |
+
- name: place_container
|
| 18 |
+
description: >
|
| 19 |
+
Place the current incoming container into a specified stack index.
|
| 20 |
+
Priority 1=urgent (retrieved first), 2=normal, 3=low (retrieved last).
|
| 21 |
+
Burying high-priority under low-priority causes rehandle costs.
|
| 22 |
+
input_schema:
|
| 23 |
+
type: object
|
| 24 |
+
properties:
|
| 25 |
+
stack_index:
|
| 26 |
+
type: integer
|
| 27 |
+
description: "Zero-indexed stack to place the container into"
|
| 28 |
+
minimum: 0
|
| 29 |
+
required:
|
| 30 |
+
- stack_index
|
pyproject.toml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-container-port"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Container yard RL environment for OpenEnv hackathon"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"fastapi>=0.110.0",
|
| 12 |
+
"uvicorn[standard]>=0.29.0",
|
| 13 |
+
"websockets>=12.0",
|
| 14 |
+
"pydantic>=2.0.0",
|
| 15 |
+
"openenv-core>=0.1.0",
|
| 16 |
+
"openai>=1.0.0",
|
| 17 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
websockets>=12.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
openenv-core>=0.1.0
|
| 6 |
+
openai>=1.0.0
|
| 7 |
+
pytest>=8.0.0
|
| 8 |
+
huggingface_hub>=0.20.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
server/__pycache__/__init__.cpython-313.pyc.1828594663216
ADDED
|
Binary file (145 Bytes). View file
|
|
|
server/__pycache__/__init__.cpython-313.pyc.2347090562224
ADDED
|
Binary file (145 Bytes). View file
|
|
|
server/__pycache__/environment.cpython-313.pyc.1828594650032
ADDED
|
Binary file (11.4 kB). View file
|
|
|
server/__pycache__/environment.cpython-313.pyc.2347106154928
ADDED
|
Binary file (11.4 kB). View file
|
|
|
server/__pycache__/models.cpython-313.pyc.2347106454640
ADDED
|
Binary file (2.34 kB). View file
|
|
|
server/__pycache__/server.cpython-313.pyc.2347090514912
ADDED
|
Binary file (3.88 kB). View file
|
|
|
server/environment.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List, Optional, Dict, Any, Tuple
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class Container:
|
| 7 |
+
id: str
|
| 8 |
+
priority: int # 1=urgent, 2=normal, 3=low
|
| 9 |
+
weight: float
|
| 10 |
+
|
| 11 |
+
DIFFICULTY_CONFIG = {
|
| 12 |
+
"easy": {
|
| 13 |
+
"n_stacks": 6,
|
| 14 |
+
"max_height": 4,
|
| 15 |
+
"n_containers": 20,
|
| 16 |
+
"retrieval_interval": 5,
|
| 17 |
+
"lookahead": 5,
|
| 18 |
+
"priority_weights": [0.4, 0.4, 0.2],
|
| 19 |
+
},
|
| 20 |
+
"medium": {
|
| 21 |
+
"n_stacks": 8,
|
| 22 |
+
"max_height": 5,
|
| 23 |
+
"n_containers": 35,
|
| 24 |
+
"retrieval_interval": 5,
|
| 25 |
+
"lookahead": 3,
|
| 26 |
+
"priority_weights": [0.33, 0.34, 0.33],
|
| 27 |
+
},
|
| 28 |
+
"hard": {
|
| 29 |
+
"n_stacks": 10,
|
| 30 |
+
"max_height": 6,
|
| 31 |
+
"n_containers": 50,
|
| 32 |
+
"retrieval_interval": 4,
|
| 33 |
+
"lookahead": 0,
|
| 34 |
+
"priority_weights": [0.25, 0.35, 0.40],
|
| 35 |
+
},
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
class ContainerYardEnv:
|
| 39 |
+
def __init__(self, difficulty: str = "medium", seed: Optional[int] = None):
|
| 40 |
+
assert difficulty in DIFFICULTY_CONFIG, f"difficulty must be one of {list(DIFFICULTY_CONFIG.keys())}"
|
| 41 |
+
self.difficulty = difficulty
|
| 42 |
+
self.seed = seed
|
| 43 |
+
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 44 |
+
self.n_stacks = cfg["n_stacks"]
|
| 45 |
+
self.max_height = cfg["max_height"]
|
| 46 |
+
self.n_containers = cfg["n_containers"]
|
| 47 |
+
self.retrieval_interval = cfg["retrieval_interval"]
|
| 48 |
+
self.lookahead = cfg["lookahead"]
|
| 49 |
+
self.priority_weights = cfg["priority_weights"]
|
| 50 |
+
self.reset()
|
| 51 |
+
|
| 52 |
+
def reset(self) -> Dict[str, Any]:
|
| 53 |
+
if self.seed is not None:
|
| 54 |
+
random.seed(self.seed)
|
| 55 |
+
self.stacks: List[List[Container]] = [[] for _ in range(self.n_stacks)]
|
| 56 |
+
self.rehandle_count = 0
|
| 57 |
+
self.step_count = 0
|
| 58 |
+
self.total_reward = 0.0
|
| 59 |
+
self.done = False
|
| 60 |
+
self.manifest: List[Container] = self._generate_manifest()
|
| 61 |
+
self.retrieval_queue: List[str] = self._generate_retrieval_queue()
|
| 62 |
+
self.retrieval_pointer = 0
|
| 63 |
+
self.current_idx = 0
|
| 64 |
+
return self._observe(last_reward=0.0)
|
| 65 |
+
|
| 66 |
+
def _generate_manifest(self) -> List[Container]:
|
| 67 |
+
containers = []
|
| 68 |
+
for i in range(self.n_containers):
|
| 69 |
+
priority = random.choices([1, 2, 3], weights=self.priority_weights)[0]
|
| 70 |
+
containers.append(Container(
|
| 71 |
+
id=f"C{i:03d}",
|
| 72 |
+
priority=priority,
|
| 73 |
+
weight=round(random.uniform(5.0, 30.0), 1)
|
| 74 |
+
))
|
| 75 |
+
return containers
|
| 76 |
+
|
| 77 |
+
def _generate_retrieval_queue(self) -> List[str]:
|
| 78 |
+
ids_by_priority = {1: [], 2: [], 3: []}
|
| 79 |
+
for c in self.manifest:
|
| 80 |
+
ids_by_priority[c.priority].append(c.id)
|
| 81 |
+
for p in ids_by_priority:
|
| 82 |
+
random.shuffle(ids_by_priority[p])
|
| 83 |
+
queue = ids_by_priority[1] + ids_by_priority[2] + ids_by_priority[3]
|
| 84 |
+
return queue
|
| 85 |
+
|
| 86 |
+
def step(self, stack_index: int) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
|
| 87 |
+
if self.done:
|
| 88 |
+
return self._observe(0.0), 0.0, True, {"error": "episode already done"}
|
| 89 |
+
|
| 90 |
+
if stack_index < 0 or stack_index >= self.n_stacks:
|
| 91 |
+
reward = -2.0
|
| 92 |
+
self.total_reward += reward
|
| 93 |
+
return self._observe(reward), reward, False, {"error": f"invalid stack_index {stack_index}, must be 0-{self.n_stacks-1}"}
|
| 94 |
+
|
| 95 |
+
if len(self.stacks[stack_index]) >= self.max_height:
|
| 96 |
+
reward = -2.0
|
| 97 |
+
self.total_reward += reward
|
| 98 |
+
return self._observe(reward), reward, False, {"error": f"stack {stack_index} is full (height {self.max_height})"}
|
| 99 |
+
|
| 100 |
+
current = self.manifest[self.current_idx]
|
| 101 |
+
self.stacks[stack_index].append(current)
|
| 102 |
+
placement_reward = self._placement_reward(stack_index, current)
|
| 103 |
+
|
| 104 |
+
self.current_idx += 1
|
| 105 |
+
self.step_count += 1
|
| 106 |
+
|
| 107 |
+
retrieval_cost = 0.0
|
| 108 |
+
retrievals_done = []
|
| 109 |
+
if self.step_count % self.retrieval_interval == 0:
|
| 110 |
+
cost, done_ids = self._trigger_retrieval()
|
| 111 |
+
retrieval_cost = cost
|
| 112 |
+
retrievals_done = done_ids
|
| 113 |
+
|
| 114 |
+
reward = placement_reward - retrieval_cost
|
| 115 |
+
self.total_reward += reward
|
| 116 |
+
self.done = (self.current_idx >= len(self.manifest))
|
| 117 |
+
|
| 118 |
+
return self._observe(reward), reward, self.done, {
|
| 119 |
+
"rehandles": self.rehandle_count,
|
| 120 |
+
"step": self.step_count,
|
| 121 |
+
"placement_reward": round(placement_reward, 4),
|
| 122 |
+
"retrieval_cost": round(retrieval_cost, 4),
|
| 123 |
+
"retrievals_done": retrievals_done,
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
def _placement_reward(self, stack_index: int, container: Container) -> float:
|
| 127 |
+
# stack_depth = zero-based index of the just-placed container
|
| 128 |
+
stack_depth = len(self.stacks[stack_index]) - 1
|
| 129 |
+
accessibility = (self.max_height - stack_depth) / self.max_height
|
| 130 |
+
priority_weight = (4 - container.priority) / 3.0 # priority 11.0, 20.67, 30.33
|
| 131 |
+
|
| 132 |
+
base = 0.3 * accessibility * priority_weight
|
| 133 |
+
|
| 134 |
+
# Bonus: high-priority container placed near top (accessible for fast retrieval)
|
| 135 |
+
if container.priority == 1 and stack_depth <= 1:
|
| 136 |
+
base += 0.15
|
| 137 |
+
|
| 138 |
+
# Penalty: placing lower-priority on top of higher-priority container (causes future rehandles)
|
| 139 |
+
if stack_depth > 0:
|
| 140 |
+
top_container = self.stacks[stack_index][-2] # container directly below
|
| 141 |
+
if container.priority > top_container.priority:
|
| 142 |
+
base -= 0.2 * (container.priority - top_container.priority) / 2.0
|
| 143 |
+
|
| 144 |
+
return round(base, 4)
|
| 145 |
+
|
| 146 |
+
def _trigger_retrieval(self) -> Tuple[float, List[str]]:
|
| 147 |
+
total_cost = 0.0
|
| 148 |
+
done_ids = []
|
| 149 |
+
for _ in range(2):
|
| 150 |
+
if self.retrieval_pointer >= len(self.retrieval_queue):
|
| 151 |
+
break
|
| 152 |
+
target_id = self.retrieval_queue[self.retrieval_pointer]
|
| 153 |
+
self.retrieval_pointer += 1
|
| 154 |
+
cost = self._retrieve(target_id)
|
| 155 |
+
total_cost += cost
|
| 156 |
+
done_ids.append(target_id)
|
| 157 |
+
return total_cost, done_ids
|
| 158 |
+
|
| 159 |
+
def _retrieve(self, target_id: str) -> float:
|
| 160 |
+
for stack in self.stacks:
|
| 161 |
+
for i, c in enumerate(stack):
|
| 162 |
+
if c.id == target_id:
|
| 163 |
+
rehandles = len(stack) - 1 - i # containers above target
|
| 164 |
+
self.rehandle_count += rehandles
|
| 165 |
+
stack.pop(i)
|
| 166 |
+
return round(rehandles * 0.4, 4)
|
| 167 |
+
return 0.0 # container not yet in yard — no penalty
|
| 168 |
+
|
| 169 |
+
def _get_upcoming_retrievals(self) -> List[str]:
|
| 170 |
+
start = self.retrieval_pointer
|
| 171 |
+
end = min(start + self.lookahead, len(self.retrieval_queue))
|
| 172 |
+
return self.retrieval_queue[start:end]
|
| 173 |
+
|
| 174 |
+
def _observe(self, last_reward: float = 0.0) -> Dict[str, Any]:
|
| 175 |
+
stack_states = []
|
| 176 |
+
for s in self.stacks:
|
| 177 |
+
stack_states.append([{"id": c.id, "priority": c.priority} for c in s])
|
| 178 |
+
|
| 179 |
+
current = None
|
| 180 |
+
if self.current_idx < len(self.manifest):
|
| 181 |
+
c = self.manifest[self.current_idx]
|
| 182 |
+
current = {"id": c.id, "priority": c.priority, "weight": c.weight}
|
| 183 |
+
|
| 184 |
+
return {
|
| 185 |
+
"stack_states": stack_states,
|
| 186 |
+
"current_container": current,
|
| 187 |
+
"upcoming_retrievals": self._get_upcoming_retrievals(),
|
| 188 |
+
"rehandle_count": self.rehandle_count,
|
| 189 |
+
"step": self.step_count,
|
| 190 |
+
"containers_remaining": len(self.manifest) - self.current_idx,
|
| 191 |
+
"n_stacks": self.n_stacks,
|
| 192 |
+
"max_height": self.max_height,
|
| 193 |
+
"difficulty": self.difficulty,
|
| 194 |
+
"last_reward": last_reward,
|
| 195 |
+
"done": self.done,
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
def get_state(self) -> Dict[str, Any]:
|
| 199 |
+
obs = self._observe()
|
| 200 |
+
obs["score"] = self.score()
|
| 201 |
+
obs["total_reward"] = round(self.total_reward, 4)
|
| 202 |
+
return obs
|
| 203 |
+
|
| 204 |
+
def score(self) -> float:
|
| 205 |
+
"""Normalized score in [0.0, 1.0]. Based on actual retrievals attempted."""
|
| 206 |
+
n_retrieved = self.retrieval_pointer # only count retrievals that actually happened
|
| 207 |
+
worst_case = n_retrieved * (self.max_height - 1)
|
| 208 |
+
if worst_case == 0:
|
| 209 |
+
return 1.0
|
| 210 |
+
score = max(0.0, 1.0 - self.rehandle_count / worst_case)
|
| 211 |
+
return round(min(score, 1.0), 4)
|
server/models.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
|
| 4 |
+
class ContainerInfo(BaseModel):
|
| 5 |
+
id: str
|
| 6 |
+
priority: int = Field(..., ge=1, le=3)
|
| 7 |
+
weight: float
|
| 8 |
+
|
| 9 |
+
class StackEntry(BaseModel):
|
| 10 |
+
id: str
|
| 11 |
+
priority: int
|
| 12 |
+
|
| 13 |
+
class ContainerAction(BaseModel):
|
| 14 |
+
stack_index: int = Field(..., description="Which stack (0-indexed) to place the current container into")
|
| 15 |
+
|
| 16 |
+
class ContainerObservation(BaseModel):
|
| 17 |
+
stack_states: List[List[Dict[str, Any]]]
|
| 18 |
+
current_container: Optional[Dict[str, Any]]
|
| 19 |
+
upcoming_retrievals: List[str]
|
| 20 |
+
rehandle_count: int
|
| 21 |
+
step: int
|
| 22 |
+
containers_remaining: int
|
| 23 |
+
n_stacks: int
|
| 24 |
+
max_height: int
|
| 25 |
+
difficulty: str
|
| 26 |
+
last_reward: float
|
| 27 |
+
done: bool
|
| 28 |
+
|
| 29 |
+
class ContainerState(BaseModel):
|
| 30 |
+
stack_states: List[List[Dict[str, Any]]]
|
| 31 |
+
rehandle_count: int
|
| 32 |
+
step: int
|
| 33 |
+
score: float
|
| 34 |
+
difficulty: str
|
| 35 |
+
done: bool
|
| 36 |
+
total_reward: float
|
server/server.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import uuid
|
| 3 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 4 |
+
from server.environment import ContainerYardEnv
|
| 5 |
+
from server.models import ContainerAction
|
| 6 |
+
|
| 7 |
+
app = FastAPI(title="Container Port OpenEnv", version="0.1.0")

# In-memory session store: session_id -> ContainerYardEnv.
# One entry per active websocket connection; removed on disconnect.
sessions: dict = {}
|
| 10 |
+
|
| 11 |
+
@app.get("/ping")
|
| 12 |
+
def ping():
|
| 13 |
+
return {"status": "ok", "env": "container-port-env"}
|
| 14 |
+
|
| 15 |
+
@app.get("/health")
|
| 16 |
+
def health():
|
| 17 |
+
return {
|
| 18 |
+
"status": "healthy",
|
| 19 |
+
"active_sessions": len(sessions),
|
| 20 |
+
"difficulties": ["easy", "medium", "hard"],
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
@app.websocket("/ws")
|
| 24 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 25 |
+
await websocket.accept()
|
| 26 |
+
session_id = str(uuid.uuid4())
|
| 27 |
+
sessions[session_id] = ContainerYardEnv(difficulty="medium")
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
while True:
|
| 31 |
+
raw = await websocket.receive_text()
|
| 32 |
+
msg = json.loads(raw)
|
| 33 |
+
msg_type = msg.get("type")
|
| 34 |
+
env = sessions[session_id]
|
| 35 |
+
|
| 36 |
+
if msg_type == "reset":
|
| 37 |
+
difficulty = msg.get("difficulty", "medium")
|
| 38 |
+
if difficulty not in ["easy", "medium", "hard"]:
|
| 39 |
+
difficulty = "medium"
|
| 40 |
+
sessions[session_id] = ContainerYardEnv(difficulty=difficulty)
|
| 41 |
+
env = sessions[session_id]
|
| 42 |
+
obs = env.reset()
|
| 43 |
+
await websocket.send_text(json.dumps({
|
| 44 |
+
"type": "reset",
|
| 45 |
+
"observation": obs,
|
| 46 |
+
"reward": 0.0,
|
| 47 |
+
"done": False,
|
| 48 |
+
"session_id": session_id,
|
| 49 |
+
}))
|
| 50 |
+
|
| 51 |
+
elif msg_type == "step":
|
| 52 |
+
try:
|
| 53 |
+
action = ContainerAction(**msg["action"])
|
| 54 |
+
obs, reward, done, info = env.step(action.stack_index)
|
| 55 |
+
await websocket.send_text(json.dumps({
|
| 56 |
+
"type": "step",
|
| 57 |
+
"observation": obs,
|
| 58 |
+
"reward": reward,
|
| 59 |
+
"done": done,
|
| 60 |
+
"info": info,
|
| 61 |
+
}))
|
| 62 |
+
except Exception as e:
|
| 63 |
+
await websocket.send_text(json.dumps({
|
| 64 |
+
"type": "error",
|
| 65 |
+
"message": str(e),
|
| 66 |
+
}))
|
| 67 |
+
|
| 68 |
+
elif msg_type == "state":
|
| 69 |
+
state = env.get_state()
|
| 70 |
+
await websocket.send_text(json.dumps({
|
| 71 |
+
"type": "state",
|
| 72 |
+
"state": state,
|
| 73 |
+
}))
|
| 74 |
+
|
| 75 |
+
else:
|
| 76 |
+
await websocket.send_text(json.dumps({
|
| 77 |
+
"type": "error",
|
| 78 |
+
"message": f"Unknown message type: {msg_type}",
|
| 79 |
+
}))
|
| 80 |
+
|
| 81 |
+
except WebSocketDisconnect:
|
| 82 |
+
pass
|
| 83 |
+
finally:
|
| 84 |
+
sessions.pop(session_id, None)
|
tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.10200
ADDED
|
Binary file (18.9 kB). View file
|
|
|
tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.16772
ADDED
|
Binary file (18.9 kB). View file
|
|
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from server.environment import ContainerYardEnv, DIFFICULTY_CONFIG
|
| 3 |
+
|
| 4 |
+
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 5 |
+
def test_reset_returns_valid_obs(difficulty):
|
| 6 |
+
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 7 |
+
obs = env.reset()
|
| 8 |
+
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 9 |
+
assert len(obs["stack_states"]) == cfg["n_stacks"]
|
| 10 |
+
assert obs["current_container"] is not None
|
| 11 |
+
assert obs["step"] == 0
|
| 12 |
+
assert obs["rehandle_count"] == 0
|
| 13 |
+
assert obs["difficulty"] == difficulty
|
| 14 |
+
assert obs["done"] == False
|
| 15 |
+
|
| 16 |
+
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 17 |
+
def test_step_valid_action(difficulty):
|
| 18 |
+
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 19 |
+
env.reset()
|
| 20 |
+
obs, reward, done, info = env.step(0)
|
| 21 |
+
assert isinstance(reward, float)
|
| 22 |
+
assert obs["step"] == 1
|
| 23 |
+
assert len(obs["stack_states"][0]) == 1
|
| 24 |
+
assert "rehandles" in info
|
| 25 |
+
|
| 26 |
+
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 27 |
+
def test_step_invalid_stack_index(difficulty):
|
| 28 |
+
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 29 |
+
env.reset()
|
| 30 |
+
obs, reward, done, info = env.step(999)
|
| 31 |
+
assert reward == -2.0
|
| 32 |
+
assert "error" in info
|
| 33 |
+
assert done == False
|
| 34 |
+
|
| 35 |
+
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 36 |
+
def test_full_episode_completes(difficulty):
|
| 37 |
+
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 38 |
+
env.reset()
|
| 39 |
+
done = False
|
| 40 |
+
steps = 0
|
| 41 |
+
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 42 |
+
n_stacks = cfg["n_stacks"]
|
| 43 |
+
max_height = cfg["max_height"]
|
| 44 |
+
while not done:
|
| 45 |
+
stacks = env._observe()["stack_states"]
|
| 46 |
+
chosen = 0
|
| 47 |
+
for i in range(n_stacks):
|
| 48 |
+
if len(stacks[i]) < max_height:
|
| 49 |
+
chosen = i
|
| 50 |
+
break
|
| 51 |
+
_, _, done, _ = env.step(chosen)
|
| 52 |
+
steps += 1
|
| 53 |
+
assert steps < 1000, "Episode did not complete in time"
|
| 54 |
+
assert done
|
| 55 |
+
|
| 56 |
+
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 57 |
+
def test_score_in_range(difficulty):
|
| 58 |
+
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 59 |
+
env.reset()
|
| 60 |
+
done = False
|
| 61 |
+
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 62 |
+
n_stacks = cfg["n_stacks"]
|
| 63 |
+
max_height = cfg["max_height"]
|
| 64 |
+
while not done:
|
| 65 |
+
stacks = env._observe()["stack_states"]
|
| 66 |
+
chosen = 0
|
| 67 |
+
for i in range(n_stacks):
|
| 68 |
+
if len(stacks[i]) < max_height:
|
| 69 |
+
chosen = i
|
| 70 |
+
break
|
| 71 |
+
_, _, done, _ = env.step(chosen)
|
| 72 |
+
score = env.score()
|
| 73 |
+
assert 0.0 <= score <= 1.0
|
| 74 |
+
|
| 75 |
+
def test_lookahead_visibility():
    """Easy mode exposes more upcoming retrievals than hard mode (which shows none)."""
    envs = {
        level: ContainerYardEnv(difficulty=level, seed=42)
        for level in ("easy", "hard")
    }
    observations = {level: env.reset() for level, env in envs.items()}
    easy_lookahead = observations["easy"]["upcoming_retrievals"]
    hard_lookahead = observations["hard"]["upcoming_retrievals"]
    assert len(easy_lookahead) > len(hard_lookahead)
    assert len(hard_lookahead) == 0
|
| 82 |
+
|
| 83 |
+
def test_reward_is_dense():
    """At least half of the per-step rewards should be non-zero (dense shaping)."""
    env = ContainerYardEnv(difficulty="medium", seed=42)
    env.reset()
    cfg = DIFFICULTY_CONFIG["medium"]
    n_stacks = cfg["n_stacks"]
    max_height = cfg["max_height"]
    rewards = []
    done = False
    step = 0
    while not done and step < 20:
        stacks = env._observe()["stack_states"]
        # Fix: derive the round-robin modulus and height cap from the config
        # instead of hard-coding 8 and 5, which silently assumed the medium
        # configuration and would break if DIFFICULTY_CONFIG changed.
        chosen = step % n_stacks
        if len(stacks[chosen]) >= max_height:
            chosen = 0
        _, r, done, _ = env.step(chosen)
        rewards.append(r)
        step += 1
    nonzero = sum(1 for r in rewards if abs(r) > 1e-6)
    assert nonzero >= len(rewards) * 0.5, f"Too many zero rewards: {rewards}"
|
| 99 |
+
|
| 100 |
+
def test_no_double_retrieval():
    """Retrieval pointer advances correctly — no container retrieved twice."""
    # Fix: removed the unused `seen_ids` set the original allocated but never read.
    env = ContainerYardEnv(difficulty="easy", seed=42)
    env.reset()
    for _ in range(env.n_containers):
        if env.done:
            break
        env.step(0 if len(env.stacks[0]) < env.max_height else 1)
    # retrieval_pointer should be <= queue length
    assert env.retrieval_pointer <= len(env.retrieval_queue)
|