Spaces:
Sleeping
Sleeping
Shabista Sehar commited on
Commit ·
0eb4f6f
1
Parent(s): 5a30ea5
construtcion env
Browse files- .env.example +7 -0
- .gitignore +15 -0
- Dockerfile +7 -4
- README.md +49 -42
- client/__init__.py +4 -0
- client/container_env.py +4 -34
- inference.py +164 -114
- models.py +40 -0
- openenv.yaml +8 -21
- pyproject.toml +12 -3
- pytest.ini +5 -0
- requirements.txt +3 -3
- server/app.py +36 -0
- server/environment.py +102 -59
- server/models.py +0 -36
- server/server.py +0 -84
- tests/conftest.py +9 -0
- tests/test_env.py +0 -110
- tests/test_openenv_env.py +170 -0
- uv.lock +0 -0
.env.example
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Container Port OpenEnv — Environment Variables
|
| 2 |
+
# Copy this file and fill in your actual values before running inference
|
| 3 |
+
HF_TOKEN=your_huggingface_token_here
|
| 4 |
+
API_BASE_URL=https://router.huggingface.co/v1
|
| 5 |
+
MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
|
| 6 |
+
ENV_URL=http://localhost:7860
|
| 7 |
+
LOCAL_IMAGE_NAME=
|
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
.venv/
|
| 6 |
+
.uv-cache/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
outputs/
|
| 11 |
+
client/__pycache__/
|
| 12 |
+
server/__pycache__/
|
| 13 |
+
tests/__pycache__/
|
| 14 |
+
.pytest_cache/
|
| 15 |
+
pytest-cache-files-*/
|
Dockerfile
CHANGED
|
@@ -2,13 +2,16 @@ FROM python:3.11-slim
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
-
COPY requirements.txt .
|
| 6 |
-
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
-
|
| 8 |
COPY . .
|
| 9 |
|
|
|
|
|
|
|
| 10 |
ENV PYTHONPATH=/app
|
|
|
|
| 11 |
|
| 12 |
EXPOSE 7860
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
COPY . .
|
| 6 |
|
| 7 |
+
RUN pip install --no-cache-dir -e .
|
| 8 |
+
|
| 9 |
ENV PYTHONPATH=/app
|
| 10 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 11 |
|
| 12 |
EXPOSE 7860
|
| 13 |
|
| 14 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
|
| 15 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')"
|
| 16 |
+
|
| 17 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,71 +1,78 @@
|
|
| 1 |
# Container Port Environment
|
| 2 |
|
| 3 |
-
An OpenEnv
|
| 4 |
|
| 5 |
## Task
|
| 6 |
|
| 7 |
-
|
| 8 |
-
stacks. At regular intervals, specific containers are retrieved. If a target is buried under others,
|
| 9 |
-
each container above it is a **rehandle** — expensive in real port operations.
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
## Difficulty Levels
|
| 14 |
|
| 15 |
-
| Parameter
|
| 16 |
-
|---
|
| 17 |
-
| Stacks
|
| 18 |
-
| Max
|
| 19 |
-
| Containers
|
| 20 |
-
| Retrieval interval |
|
| 21 |
-
| Lookahead
|
| 22 |
|
| 23 |
-
##
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
| Burying high-priority under low-priority | -0.10 to -0.20 |
|
| 30 |
-
| Invalid action (full stack / bad index) | -2.0 |
|
| 31 |
-
| Each rehandle at retrieval time | -0.40 |
|
| 32 |
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
| 36 |
|
| 37 |
-
## Setup
|
| 38 |
```bash
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
```
|
| 42 |
|
| 43 |
-
|
| 44 |
-
```bash
|
| 45 |
-
# Greedy agent, all difficulties
|
| 46 |
-
python inference.py --difficulty all
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
export HF_TOKEN=hf_your_token_here
|
| 50 |
-
python inference.py --use-llm --difficulty all
|
| 51 |
|
| 52 |
-
|
| 53 |
-
python inference.py --
|
|
|
|
|
|
|
| 54 |
```
|
| 55 |
|
|
|
|
|
|
|
| 56 |
## Docker
|
|
|
|
| 57 |
```bash
|
| 58 |
docker build -t container-port-env .
|
| 59 |
docker run -p 7860:7860 container-port-env
|
| 60 |
```
|
| 61 |
|
| 62 |
-
##
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
-
|
| 69 |
-
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Container Port Environment
|
| 2 |
|
| 3 |
+
An OpenEnv environment for container-yard stack planning at a shipping terminal.
|
| 4 |
|
| 5 |
## Task
|
| 6 |
|
| 7 |
+
Incoming containers have priority `1`, `2`, or `3`. The agent places each one into a bounded stack. During retrieval, every container sitting above the target counts as a rehandle and adds cost.
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
Goal: minimize total rehandles across the episode.
|
| 10 |
|
| 11 |
## Difficulty Levels
|
| 12 |
|
| 13 |
+
| Parameter | Easy | Medium | Hard |
|
| 14 |
+
|---|---|---|---|
|
| 15 |
+
| Stacks | 6 | 8 | 10 |
|
| 16 |
+
| Max height | 4 | 5 | 6 |
|
| 17 |
+
| Containers | 20 | 35 | 50 |
|
| 18 |
+
| Retrieval interval | 5 | 5 | 4 |
|
| 19 |
+
| Lookahead | 5 | 3 | 0 |
|
| 20 |
|
| 21 |
+
## Run Locally
|
| 22 |
|
| 23 |
+
```bash
|
| 24 |
+
pip install -e .
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 26 |
+
```
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
Web UI: `http://127.0.0.1:7860/web`
|
| 29 |
|
| 30 |
+
For manual stateful checks, use the web endpoints:
|
| 31 |
|
|
|
|
| 32 |
```bash
|
| 33 |
+
curl http://127.0.0.1:7860/health
|
| 34 |
+
curl -X POST http://127.0.0.1:7860/web/reset -H "Content-Type: application/json" -d "{\"difficulty\":\"easy\"}"
|
| 35 |
+
curl -X POST http://127.0.0.1:7860/web/step -H "Content-Type: application/json" -d "{\"action\":{\"stack_index\":0}}"
|
| 36 |
```
|
| 37 |
|
| 38 |
+
`/reset` and `/step` are stateless simulation endpoints in `openenv-core 0.2.3`. For browser-style interactive testing, use `/web`, `/web/reset`, `/web/step`, or the WebSocket flow used by `inference.py`.
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
## Run Inference
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
```bash
|
| 43 |
+
python inference.py --difficulty all
|
| 44 |
+
python inference.py --difficulty easy
|
| 45 |
+
python inference.py --url http://127.0.0.1:7860 --difficulty all
|
| 46 |
```
|
| 47 |
|
| 48 |
+
For LLM mode, set `HF_TOKEN` first.
|
| 49 |
+
|
| 50 |
## Docker
|
| 51 |
+
|
| 52 |
```bash
|
| 53 |
docker build -t container-port-env .
|
| 54 |
docker run -p 7860:7860 container-port-env
|
| 55 |
```
|
| 56 |
|
| 57 |
+
## Tests
|
| 58 |
+
|
| 59 |
+
Run the full test suite:
|
| 60 |
|
| 61 |
+
```bash
|
| 62 |
+
pytest tests/test_openenv_env.py -v
|
| 63 |
+
```
|
| 64 |
|
| 65 |
+
| Test | What it covers |
|
| 66 |
+
|---|---|
|
| 67 |
+
| test_reset_returns_valid_obs | Reset returns correct stack count, step=0, no rehandles |
|
| 68 |
+
| test_step_valid_action | Valid placement increments step and fills stack |
|
| 69 |
+
| test_step_invalid_action_penalized | Out-of-range stack index returns -2.0 reward |
|
| 70 |
+
| test_score_in_range | Full episode score stays in [0.0, 1.0] |
|
| 71 |
+
| test_full_episode_completes | All 3 difficulties reach done=True within 500 steps |
|
| 72 |
+
| test_lookahead_visibility | Easy shows more upcoming retrievals than hard (hard=0) |
|
| 73 |
+
| test_reward_is_dense | At least 50% of steps have non-zero reward |
|
| 74 |
+
| test_no_double_retrieval | retrieval_pointer never exceeds queue length |
|
| 75 |
+
| test_health_route | GET /health returns 200 |
|
| 76 |
+
| test_web_ui_route | GET /web returns 200 (Gradio UI) |
|
| 77 |
+
| test_http_reset_returns_observation | POST /reset returns valid easy-mode observation |
|
| 78 |
+
| test_http_reset_then_step_preserves_state | Sequential reset+step operates on same episode |
|
client/__init__.py
CHANGED
|
@@ -1 +1,5 @@
|
|
|
|
|
| 1 |
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from client.container_env import ContainerEnvClient
|
| 2 |
|
| 3 |
+
ContainerPortEnv = ContainerEnvClient
|
| 4 |
+
|
| 5 |
+
__all__ = ["ContainerEnvClient", "ContainerPortEnv"]
|
client/container_env.py
CHANGED
|
@@ -1,37 +1,7 @@
|
|
| 1 |
-
import
|
| 2 |
-
import websockets
|
| 3 |
-
from typing import Any, Dict, Tuple
|
| 4 |
|
| 5 |
-
class ContainerEnvClient:
|
| 6 |
-
"""Async client for Container Port OpenEnv."""
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
self.ws_url = ws_url.rstrip("/") + "/ws"
|
| 11 |
-
self._ws = None
|
| 12 |
|
| 13 |
-
|
| 14 |
-
self._ws = await websockets.connect(self.ws_url)
|
| 15 |
-
return self
|
| 16 |
-
|
| 17 |
-
async def __aexit__(self, *args):
|
| 18 |
-
if self._ws:
|
| 19 |
-
await self._ws.close()
|
| 20 |
-
|
| 21 |
-
async def reset(self, difficulty: str = "medium") -> Dict[str, Any]:
|
| 22 |
-
await self._ws.send(json.dumps({"type": "reset", "difficulty": difficulty}))
|
| 23 |
-
resp = json.loads(await self._ws.recv())
|
| 24 |
-
return resp["observation"]
|
| 25 |
-
|
| 26 |
-
async def step(self, stack_index: int) -> Tuple[Dict, float, bool, Dict]:
|
| 27 |
-
await self._ws.send(json.dumps({
|
| 28 |
-
"type": "step",
|
| 29 |
-
"action": {"stack_index": stack_index}
|
| 30 |
-
}))
|
| 31 |
-
resp = json.loads(await self._ws.recv())
|
| 32 |
-
return resp["observation"], resp["reward"], resp["done"], resp.get("info", {})
|
| 33 |
-
|
| 34 |
-
async def state(self) -> Dict[str, Any]:
|
| 35 |
-
await self._ws.send(json.dumps({"type": "state"}))
|
| 36 |
-
resp = json.loads(await self._ws.recv())
|
| 37 |
-
return resp["state"]
|
|
|
|
| 1 |
+
from openenv import GenericEnvClient
|
|
|
|
|
|
|
| 2 |
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
class ContainerEnvClient(GenericEnvClient):
|
| 5 |
+
"""OpenEnv client for Container Port."""
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py
CHANGED
|
@@ -1,50 +1,90 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Container Port OpenEnv — Baseline Inference Script
|
| 4 |
-
SST x Meta PyTorch OpenEnv Hackathon
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
|
| 11 |
Usage:
|
| 12 |
python inference.py
|
| 13 |
-
python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space --difficulty all
|
| 14 |
python inference.py --difficulty easy
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
|
|
|
|
| 17 |
import os
|
| 18 |
import sys
|
| 19 |
import json
|
| 20 |
-
import
|
| 21 |
-
import argparse
|
| 22 |
-
import websockets
|
| 23 |
from openai import OpenAI
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Required configuration variables
|
| 26 |
-
API_BASE_URL
|
| 27 |
-
MODEL_NAME
|
| 28 |
-
HF_TOKEN
|
|
|
|
|
|
|
| 29 |
#
|
| 30 |
|
| 31 |
-
ENV_URL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
def _llm_client() -> OpenAI:
|
| 35 |
-
"""Return an OpenAI-compatible client pointed at HF Inference Router."""
|
| 36 |
-
return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def greedy_decide(obs: dict) -> int:
|
| 40 |
-
""
|
| 41 |
-
|
| 42 |
-
Scores each valid stack by accessibility and priority compatibility.
|
| 43 |
-
"""
|
| 44 |
-
stacks = obs["stack_states"]
|
| 45 |
-
current = obs.get("current_container")
|
| 46 |
max_height = obs["max_height"]
|
| 47 |
-
upcoming
|
| 48 |
|
| 49 |
if current is None:
|
| 50 |
return 0
|
|
@@ -56,16 +96,15 @@ def greedy_decide(obs: dict) -> int:
|
|
| 56 |
depth = len(stack)
|
| 57 |
if depth >= max_height:
|
| 58 |
continue
|
| 59 |
-
|
| 60 |
score = 0.0
|
| 61 |
accessibility = (max_height - depth) / max_height
|
| 62 |
score += accessibility * (4 - cur_priority)
|
| 63 |
|
| 64 |
if depth > 0:
|
| 65 |
-
|
| 66 |
-
if cur_priority >
|
| 67 |
-
score -= 10.0 * (cur_priority -
|
| 68 |
-
elif cur_priority <
|
| 69 |
score += 3.0
|
| 70 |
|
| 71 |
if current["id"] in upcoming:
|
|
@@ -82,53 +121,48 @@ def greedy_decide(obs: dict) -> int:
|
|
| 82 |
for i, stack in enumerate(stacks):
|
| 83 |
if len(stack) < max_height:
|
| 84 |
return i
|
| 85 |
-
return best_stack
|
| 86 |
|
| 87 |
|
| 88 |
-
def llm_decide(obs: dict) -> int:
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
n_stacks = obs["n_stacks"]
|
| 93 |
max_height = obs["max_height"]
|
| 94 |
-
upcoming
|
| 95 |
difficulty = obs.get("difficulty", "medium")
|
| 96 |
|
| 97 |
-
|
| 98 |
for i, stack in enumerate(stacks):
|
| 99 |
if not stack:
|
| 100 |
-
|
| 101 |
else:
|
| 102 |
contents = ", ".join(f"{c['id']}(p{c['priority']})" for c in stack)
|
| 103 |
-
|
| 104 |
f" Stack {i}: [{contents}] depth={len(stack)}/{max_height},"
|
| 105 |
f" top=priority-{stack[-1]['priority']}"
|
| 106 |
)
|
| 107 |
|
| 108 |
prompt = (
|
| 109 |
-
f"You are
|
| 110 |
-
f"
|
| 111 |
-
f"RULE:
|
| 112 |
-
f"Priority 1=URGENT (retrieved first), 2=Normal, 3=Low (retrieved last).\n\n"
|
| 113 |
f"DIFFICULTY: {difficulty}\n"
|
| 114 |
-
f"UPCOMING RETRIEVALS
|
| 115 |
-
f"{upcoming if upcoming else 'Unknown (hard mode)'}\n\n"
|
| 116 |
f"CONTAINER TO PLACE: id={current['id']}, priority={current['priority']}, "
|
| 117 |
f"weight={current['weight']}kg\n\n"
|
| 118 |
-
f"
|
| 119 |
-
f"
|
| 120 |
)
|
| 121 |
|
| 122 |
try:
|
| 123 |
-
|
| 124 |
-
response = client.chat.completions.create(
|
| 125 |
model=MODEL_NAME,
|
| 126 |
max_tokens=64,
|
| 127 |
temperature=0.0,
|
| 128 |
messages=[{"role": "user", "content": prompt}],
|
| 129 |
)
|
| 130 |
-
text =
|
| 131 |
-
# strip markdown fences if model wraps in ```json ... ```
|
| 132 |
if "```" in text:
|
| 133 |
text = text.split("```")[1]
|
| 134 |
if text.startswith("json"):
|
|
@@ -138,84 +172,100 @@ def llm_decide(obs: dict) -> int:
|
|
| 138 |
if 0 <= idx < n_stacks and len(obs["stack_states"][idx]) < max_height:
|
| 139 |
return idx
|
| 140 |
except Exception as e:
|
| 141 |
-
print(f"
|
| 142 |
|
| 143 |
return greedy_decide(obs)
|
| 144 |
|
| 145 |
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
| 148 |
if not ws_url.endswith("/ws"):
|
| 149 |
ws_url = ws_url.rstrip("/") + "/ws"
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
"
|
| 180 |
-
|
| 181 |
-
"
|
| 182 |
-
"done"
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
for diff in ["easy", "medium", "hard"]:
|
| 211 |
await run_episode(url, difficulty=diff, use_llm=use_llm)
|
| 212 |
|
| 213 |
|
|
|
|
|
|
|
| 214 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 215 |
parser = argparse.ArgumentParser(description="Container Port Baseline Agent")
|
| 216 |
parser.add_argument("--url", default=ENV_URL)
|
| 217 |
-
parser.add_argument("--difficulty", default="all",
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
args = parser.parse_args()
|
| 220 |
|
| 221 |
if args.difficulty == "all":
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Container Port OpenEnv — Baseline Inference Script
|
| 4 |
+
SST x Meta PyTorch OpenEnv Hackathon 2026
|
| 5 |
|
| 6 |
+
Stdout format (grader parses these exactly):
|
| 7 |
+
[START] task=<task> env=container-port-env model=<model>
|
| 8 |
+
[STEP] step=<n> action=<stack_idx> reward=<0.00> done=<true|false> error=<msg|null>
|
| 9 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
|
| 10 |
|
| 11 |
Usage:
|
| 12 |
python inference.py
|
|
|
|
| 13 |
python inference.py --difficulty easy
|
| 14 |
+
python inference.py --difficulty all
|
| 15 |
+
python inference.py --use-llm
|
| 16 |
+
python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space
|
| 17 |
"""
|
| 18 |
|
| 19 |
+
import asyncio
|
| 20 |
import os
|
| 21 |
import sys
|
| 22 |
import json
|
| 23 |
+
from typing import List, Optional
|
|
|
|
|
|
|
| 24 |
from openai import OpenAI
|
| 25 |
|
| 26 |
+
# Load .env file if present (before os.getenv calls)
|
| 27 |
+
def _load_dotenv():
|
| 28 |
+
env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
|
| 29 |
+
if os.path.exists(env_path):
|
| 30 |
+
with open(env_path) as f:
|
| 31 |
+
for line in f:
|
| 32 |
+
line = line.strip()
|
| 33 |
+
if not line or line.startswith("#") or "=" not in line:
|
| 34 |
+
continue
|
| 35 |
+
key, _, value = line.partition("=")
|
| 36 |
+
key = key.strip()
|
| 37 |
+
value = value.strip().strip('"').strip("'")
|
| 38 |
+
if key and key not in os.environ:
|
| 39 |
+
os.environ[key] = value
|
| 40 |
+
|
| 41 |
+
_load_dotenv()
|
| 42 |
+
|
| 43 |
# Required configuration variables
|
| 44 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 45 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
|
| 46 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 47 |
+
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 48 |
+
API_KEY = HF_TOKEN or os.getenv("API_KEY")
|
| 49 |
#
|
| 50 |
|
| 51 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 52 |
+
TASK_NAME = "container-stacking"
|
| 53 |
+
BENCHMARK = "container-port-env"
|
| 54 |
+
MAX_STEPS = 200 # hard mode has 50 containers, safety ceiling
|
| 55 |
+
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Logging helpers (exact SST format)
|
| 59 |
|
| 60 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 61 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 65 |
+
error_val = error if error else "null"
|
| 66 |
+
done_val = str(done).lower()
|
| 67 |
+
print(
|
| 68 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 69 |
+
flush=True,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 74 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 75 |
+
print(
|
| 76 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 77 |
+
flush=True,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Agents
|
| 82 |
|
| 83 |
def greedy_decide(obs: dict) -> int:
|
| 84 |
+
stacks = obs["stack_states"]
|
| 85 |
+
current = obs.get("current_container")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
max_height = obs["max_height"]
|
| 87 |
+
upcoming = set(obs.get("upcoming_retrievals", []))
|
| 88 |
|
| 89 |
if current is None:
|
| 90 |
return 0
|
|
|
|
| 96 |
depth = len(stack)
|
| 97 |
if depth >= max_height:
|
| 98 |
continue
|
|
|
|
| 99 |
score = 0.0
|
| 100 |
accessibility = (max_height - depth) / max_height
|
| 101 |
score += accessibility * (4 - cur_priority)
|
| 102 |
|
| 103 |
if depth > 0:
|
| 104 |
+
top_p = stack[-1]["priority"]
|
| 105 |
+
if cur_priority > top_p:
|
| 106 |
+
score -= 10.0 * (cur_priority - top_p)
|
| 107 |
+
elif cur_priority < top_p:
|
| 108 |
score += 3.0
|
| 109 |
|
| 110 |
if current["id"] in upcoming:
|
|
|
|
| 121 |
for i, stack in enumerate(stacks):
|
| 122 |
if len(stack) < max_height:
|
| 123 |
return i
|
| 124 |
+
return max(best_stack, 0)
|
| 125 |
|
| 126 |
|
| 127 |
+
def llm_decide(obs: dict, client: OpenAI) -> int:
|
| 128 |
+
stacks = obs["stack_states"]
|
| 129 |
+
current = obs.get("current_container")
|
| 130 |
+
n_stacks = obs["n_stacks"]
|
|
|
|
| 131 |
max_height = obs["max_height"]
|
| 132 |
+
upcoming = obs.get("upcoming_retrievals", [])
|
| 133 |
difficulty = obs.get("difficulty", "medium")
|
| 134 |
|
| 135 |
+
lines = []
|
| 136 |
for i, stack in enumerate(stacks):
|
| 137 |
if not stack:
|
| 138 |
+
lines.append(f" Stack {i}: EMPTY (0/{max_height})")
|
| 139 |
else:
|
| 140 |
contents = ", ".join(f"{c['id']}(p{c['priority']})" for c in stack)
|
| 141 |
+
lines.append(
|
| 142 |
f" Stack {i}: [{contents}] depth={len(stack)}/{max_height},"
|
| 143 |
f" top=priority-{stack[-1]['priority']}"
|
| 144 |
)
|
| 145 |
|
| 146 |
prompt = (
|
| 147 |
+
f"You are a container yard planner. Minimize rehandle operations.\n"
|
| 148 |
+
f"Priority 1=URGENT (retrieved first), 2=Normal, 3=Low.\n"
|
| 149 |
+
f"RULE: containers above the target at retrieval = rehandles (costly).\n\n"
|
|
|
|
| 150 |
f"DIFFICULTY: {difficulty}\n"
|
| 151 |
+
f"UPCOMING RETRIEVALS: {upcoming or 'Unknown (hard mode)'}\n\n"
|
|
|
|
| 152 |
f"CONTAINER TO PLACE: id={current['id']}, priority={current['priority']}, "
|
| 153 |
f"weight={current['weight']}kg\n\n"
|
| 154 |
+
f"STACKS (bottom->top):\n" + "\n".join(lines) + "\n\n"
|
| 155 |
+
f"Reply ONLY with valid JSON: {{\"stack_index\": <int 0-{n_stacks-1}>}}"
|
| 156 |
)
|
| 157 |
|
| 158 |
try:
|
| 159 |
+
resp = client.chat.completions.create(
|
|
|
|
| 160 |
model=MODEL_NAME,
|
| 161 |
max_tokens=64,
|
| 162 |
temperature=0.0,
|
| 163 |
messages=[{"role": "user", "content": prompt}],
|
| 164 |
)
|
| 165 |
+
text = resp.choices[0].message.content.strip()
|
|
|
|
| 166 |
if "```" in text:
|
| 167 |
text = text.split("```")[1]
|
| 168 |
if text.startswith("json"):
|
|
|
|
| 172 |
if 0 <= idx < n_stacks and len(obs["stack_states"][idx]) < max_height:
|
| 173 |
return idx
|
| 174 |
except Exception as e:
|
| 175 |
+
print(f"[DEBUG] LLM fallback: {e}", file=sys.stderr, flush=True)
|
| 176 |
|
| 177 |
return greedy_decide(obs)
|
| 178 |
|
| 179 |
|
| 180 |
+
# Episode runner
|
| 181 |
+
|
| 182 |
+
async def run_episode(
|
| 183 |
+
url: str,
|
| 184 |
+
difficulty: str = "medium",
|
| 185 |
+
use_llm: bool = False,
|
| 186 |
+
) -> float:
|
| 187 |
+
import websockets
|
| 188 |
+
|
| 189 |
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
| 190 |
if not ws_url.endswith("/ws"):
|
| 191 |
ws_url = ws_url.rstrip("/") + "/ws"
|
| 192 |
|
| 193 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if use_llm else None
|
| 194 |
+
model_label = MODEL_NAME if use_llm else "greedy"
|
| 195 |
+
|
| 196 |
+
log_start(task=f"{TASK_NAME}-{difficulty}", env=BENCHMARK, model=model_label)
|
| 197 |
+
|
| 198 |
+
rewards: List[float] = []
|
| 199 |
+
steps_taken = 0
|
| 200 |
+
score = 0.0
|
| 201 |
+
success = False
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
async with websockets.connect(ws_url) as ws:
|
| 205 |
+
await ws.send(json.dumps({"type": "reset", "data": {"difficulty": difficulty}}))
|
| 206 |
+
resp = json.loads(await ws.recv())
|
| 207 |
+
payload = resp.get("data", {})
|
| 208 |
+
obs = payload.get("observation", payload)
|
| 209 |
+
|
| 210 |
+
for step in range(1, MAX_STEPS + 1):
|
| 211 |
+
if obs.get("done", False):
|
| 212 |
+
break
|
| 213 |
+
|
| 214 |
+
action_idx = llm_decide(obs, client) if use_llm else greedy_decide(obs)
|
| 215 |
+
|
| 216 |
+
await ws.send(json.dumps({
|
| 217 |
+
"type": "step",
|
| 218 |
+
"data": {"stack_index": action_idx},
|
| 219 |
+
}))
|
| 220 |
+
resp = json.loads(await ws.recv())
|
| 221 |
+
payload = resp.get("data", {})
|
| 222 |
+
obs = payload.get("observation", payload)
|
| 223 |
+
reward = float(payload.get("reward", obs.get("last_reward", 0.0)) or obs.get("last_reward", 0.0))
|
| 224 |
+
done = payload.get("done", obs.get("done", False))
|
| 225 |
+
error = payload.get("error", None)
|
| 226 |
+
|
| 227 |
+
rewards.append(reward)
|
| 228 |
+
steps_taken = step
|
| 229 |
+
|
| 230 |
+
log_step(step=step, action=str(action_idx), reward=reward, done=done, error=error)
|
| 231 |
+
|
| 232 |
+
if done:
|
| 233 |
+
break
|
| 234 |
+
|
| 235 |
+
# Fetch final score
|
| 236 |
+
await ws.send(json.dumps({"type": "state"}))
|
| 237 |
+
state_resp = json.loads(await ws.recv())
|
| 238 |
+
state = state_resp.get("data", {})
|
| 239 |
+
score = float(state.get("score", obs.get("score", 0.0)))
|
| 240 |
+
score = min(max(score, 0.0), 1.0)
|
| 241 |
+
|
| 242 |
+
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 243 |
+
|
| 244 |
+
except Exception as exc:
|
| 245 |
+
print(f"[DEBUG] Episode error: {exc}", file=sys.stderr, flush=True)
|
| 246 |
+
|
| 247 |
+
finally:
|
| 248 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 249 |
+
|
| 250 |
+
return score
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
async def run_all(url: str, use_llm: bool = False) -> None:
|
| 254 |
for diff in ["easy", "medium", "hard"]:
|
| 255 |
await run_episode(url, difficulty=diff, use_llm=use_llm)
|
| 256 |
|
| 257 |
|
| 258 |
+
# Entry point
|
| 259 |
+
|
| 260 |
if __name__ == "__main__":
|
| 261 |
+
import argparse
|
| 262 |
+
|
| 263 |
parser = argparse.ArgumentParser(description="Container Port Baseline Agent")
|
| 264 |
parser.add_argument("--url", default=ENV_URL)
|
| 265 |
+
parser.add_argument("--difficulty", default="all",
|
| 266 |
+
choices=["easy", "medium", "hard", "all"])
|
| 267 |
+
parser.add_argument("--use-llm", action="store_true",
|
| 268 |
+
help="Use LLM agent via HF router (requires HF_TOKEN)")
|
| 269 |
args = parser.parse_args()
|
| 270 |
|
| 271 |
if args.difficulty == "all":
|
models.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server.types import Action, Observation
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ContainerAction(Action):
|
| 10 |
+
"""Place the current container into a stack."""
|
| 11 |
+
|
| 12 |
+
stack_index: int = Field(
|
| 13 |
+
...,
|
| 14 |
+
description="Zero-indexed stack to place the incoming container into",
|
| 15 |
+
ge=0,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ContainerObservation(Observation):
|
| 20 |
+
"""Observation returned after each step."""
|
| 21 |
+
|
| 22 |
+
stack_states: list[list[dict[str, Any]]] = Field(
|
| 23 |
+
..., description="Each stack is a list of {id, priority} dicts (bottom->top)"
|
| 24 |
+
)
|
| 25 |
+
current_container: dict[str, Any] | None = Field(
|
| 26 |
+
None, description="Container to place now: {id, priority, weight}"
|
| 27 |
+
)
|
| 28 |
+
upcoming_retrievals: list[str] = Field(
|
| 29 |
+
default_factory=list,
|
| 30 |
+
description="IDs of next containers to be retrieved (lookahead)",
|
| 31 |
+
)
|
| 32 |
+
rehandle_count: int = Field(0, description="Cumulative rehandles so far")
|
| 33 |
+
step: int = Field(0, description="Steps completed")
|
| 34 |
+
containers_remaining: int = Field(0)
|
| 35 |
+
n_stacks: int = Field(0)
|
| 36 |
+
max_height: int = Field(0)
|
| 37 |
+
difficulty: str = Field("medium")
|
| 38 |
+
last_reward: float = Field(0.0)
|
| 39 |
+
score: float = Field(0.0, description="Normalized score 0.0-1.0")
|
| 40 |
+
done: bool = Field(False)
|
openenv.yaml
CHANGED
|
@@ -1,30 +1,17 @@
|
|
|
|
|
| 1 |
name: container-port-env
|
| 2 |
version: "0.1.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
description: >
|
| 4 |
Container terminal yard RL environment. An agent places incoming ship
|
| 5 |
-
containers into stacks
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
visibility. Models real port logistics decision-making.
|
| 9 |
tags:
|
| 10 |
- logistics
|
| 11 |
- planning
|
| 12 |
- real-world
|
| 13 |
- combinatorial-optimization
|
| 14 |
-
sdk: docker
|
| 15 |
-
entry_point: server.server:app
|
| 16 |
-
tools:
|
| 17 |
-
- name: place_container
|
| 18 |
-
description: >
|
| 19 |
-
Place the current incoming container into a specified stack index.
|
| 20 |
-
Priority 1=urgent (retrieved first), 2=normal, 3=low (retrieved last).
|
| 21 |
-
Burying high-priority under low-priority causes rehandle costs.
|
| 22 |
-
input_schema:
|
| 23 |
-
type: object
|
| 24 |
-
properties:
|
| 25 |
-
stack_index:
|
| 26 |
-
type: integer
|
| 27 |
-
description: "Zero-indexed stack to place the container into"
|
| 28 |
-
minimum: 0
|
| 29 |
-
required:
|
| 30 |
-
- stack_index
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
name: container-port-env
|
| 3 |
version: "0.1.0"
|
| 4 |
+
type: standard
|
| 5 |
+
runtime: docker
|
| 6 |
+
app: server.app:app
|
| 7 |
+
port: 7860
|
| 8 |
description: >
|
| 9 |
Container terminal yard RL environment. An agent places incoming ship
|
| 10 |
+
containers into stacks to minimize costly rehandle operations during
|
| 11 |
+
retrieval. Three difficulty levels: easy (6 stacks, lookahead 5),
|
| 12 |
+
medium (8 stacks, lookahead 3), hard (10 stacks, no lookahead).
|
|
|
|
| 13 |
tags:
|
| 14 |
- logistics
|
| 15 |
- planning
|
| 16 |
- real-world
|
| 17 |
- combinatorial-optimization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
|
@@ -2,16 +2,25 @@
|
|
| 2 |
requires = ["setuptools>=68"]
|
| 3 |
build-backend = "setuptools.build_meta"
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
[project]
|
| 6 |
name = "openenv-container-port"
|
| 7 |
version = "0.1.0"
|
| 8 |
-
description = "Container yard RL environment
|
| 9 |
requires-python = ">=3.10"
|
| 10 |
dependencies = [
|
|
|
|
| 11 |
"fastapi>=0.110.0",
|
| 12 |
"uvicorn[standard]>=0.29.0",
|
| 13 |
-
"websockets>=12.0",
|
| 14 |
"pydantic>=2.0.0",
|
| 15 |
-
"openenv-core>=0.1.0",
|
| 16 |
"openai>=1.0.0",
|
|
|
|
|
|
|
|
|
|
| 17 |
]
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
requires = ["setuptools>=68"]
|
| 3 |
build-backend = "setuptools.build_meta"
|
| 4 |
|
| 5 |
+
[tool.setuptools]
|
| 6 |
+
packages = ["client", "server"]
|
| 7 |
+
py-modules = ["models"]
|
| 8 |
+
|
| 9 |
[project]
|
| 10 |
name = "openenv-container-port"
|
| 11 |
version = "0.1.0"
|
| 12 |
+
description = "Container yard RL environment - SST x Meta PyTorch OpenEnv Hackathon"
|
| 13 |
requires-python = ">=3.10"
|
| 14 |
dependencies = [
|
| 15 |
+
"openenv-core>=0.2.0",
|
| 16 |
"fastapi>=0.110.0",
|
| 17 |
"uvicorn[standard]>=0.29.0",
|
|
|
|
| 18 |
"pydantic>=2.0.0",
|
|
|
|
| 19 |
"openai>=1.0.0",
|
| 20 |
+
"websockets>=12.0",
|
| 21 |
+
"huggingface_hub>=0.20.0",
|
| 22 |
+
"pytest>=8.0.0",
|
| 23 |
]
|
| 24 |
+
|
| 25 |
+
[project.scripts]
|
| 26 |
+
server = "server.app:main"
|
pytest.ini
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
addopts = -p no:cacheprovider
|
| 3 |
+
testpaths = tests
|
| 4 |
+
python_files = test_openenv_env.py
|
| 5 |
+
norecursedirs = .git .venv __pycache__ .pytest_cache pytest-cache-files-*
|
requirements.txt
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
|
|
| 1 |
fastapi>=0.110.0
|
| 2 |
uvicorn[standard]>=0.29.0
|
| 3 |
-
websockets>=12.0
|
| 4 |
pydantic>=2.0.0
|
| 5 |
-
openenv-core>=0.1.0
|
| 6 |
openai>=1.0.0
|
| 7 |
-
|
| 8 |
huggingface_hub>=0.20.0
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.0
|
| 2 |
fastapi>=0.110.0
|
| 3 |
uvicorn[standard]>=0.29.0
|
|
|
|
| 4 |
pydantic>=2.0.0
|
|
|
|
| 5 |
openai>=1.0.0
|
| 6 |
+
websockets>=12.0
|
| 7 |
huggingface_hub>=0.20.0
|
| 8 |
+
pytest>=8.0.0
|
server/app.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI app for Container Port Environment.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 14 |
+
|
| 15 |
+
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import create_web_interface_app
|
| 18 |
+
import uvicorn
|
| 19 |
+
|
| 20 |
+
from models import ContainerAction, ContainerObservation
|
| 21 |
+
from server.environment import ContainerYardEnvironment
|
| 22 |
+
|
| 23 |
+
app = create_web_interface_app(
|
| 24 |
+
ContainerYardEnvironment,
|
| 25 |
+
ContainerAction,
|
| 26 |
+
ContainerObservation,
|
| 27 |
+
env_name="container-port-env",
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> None:
|
| 32 |
+
uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
main()
|
server/environment.py
CHANGED
|
@@ -1,11 +1,20 @@
|
|
|
|
|
|
|
|
| 1 |
import random
|
|
|
|
| 2 |
from dataclasses import dataclass
|
| 3 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
class Container:
|
| 7 |
id: str
|
| 8 |
-
priority: int
|
| 9 |
weight: float
|
| 10 |
|
| 11 |
DIFFICULTY_CONFIG = {
|
|
@@ -35,11 +44,18 @@ DIFFICULTY_CONFIG = {
|
|
| 35 |
},
|
| 36 |
}
|
| 37 |
|
| 38 |
-
class
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 44 |
self.n_stacks = cfg["n_stacks"]
|
| 45 |
self.max_height = cfg["max_height"]
|
|
@@ -47,23 +63,18 @@ class ContainerYardEnv:
|
|
| 47 |
self.retrieval_interval = cfg["retrieval_interval"]
|
| 48 |
self.lookahead = cfg["lookahead"]
|
| 49 |
self.priority_weights = cfg["priority_weights"]
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
if self.seed is not None:
|
| 54 |
-
random.seed(self.seed)
|
| 55 |
-
self.stacks: List[List[Container]] = [[] for _ in range(self.n_stacks)]
|
| 56 |
self.rehandle_count = 0
|
| 57 |
-
self.step_count = 0
|
| 58 |
self.total_reward = 0.0
|
| 59 |
self.done = False
|
| 60 |
-
self.manifest:
|
| 61 |
-
self.retrieval_queue:
|
| 62 |
self.retrieval_pointer = 0
|
| 63 |
self.current_idx = 0
|
| 64 |
-
return self._observe(last_reward=0.0)
|
| 65 |
|
| 66 |
-
def _generate_manifest(self) ->
|
| 67 |
containers = []
|
| 68 |
for i in range(self.n_containers):
|
| 69 |
priority = random.choices([1, 2, 3], weights=self.priority_weights)[0]
|
|
@@ -74,7 +85,7 @@ class ContainerYardEnv:
|
|
| 74 |
))
|
| 75 |
return containers
|
| 76 |
|
| 77 |
-
def _generate_retrieval_queue(self) ->
|
| 78 |
ids_by_priority = {1: [], 2: [], 3: []}
|
| 79 |
for c in self.manifest:
|
| 80 |
ids_by_priority[c.priority].append(c.id)
|
|
@@ -83,45 +94,62 @@ class ContainerYardEnv:
|
|
| 83 |
queue = ids_by_priority[1] + ids_by_priority[2] + ids_by_priority[3]
|
| 84 |
return queue
|
| 85 |
|
| 86 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
if self.done:
|
| 88 |
-
return self._observe(0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
if stack_index < 0 or stack_index >= self.n_stacks:
|
| 91 |
reward = -2.0
|
| 92 |
self.total_reward += reward
|
| 93 |
-
|
|
|
|
| 94 |
|
| 95 |
if len(self.stacks[stack_index]) >= self.max_height:
|
| 96 |
reward = -2.0
|
| 97 |
self.total_reward += reward
|
| 98 |
-
|
|
|
|
| 99 |
|
| 100 |
current = self.manifest[self.current_idx]
|
| 101 |
self.stacks[stack_index].append(current)
|
| 102 |
placement_reward = self._placement_reward(stack_index, current)
|
| 103 |
|
| 104 |
self.current_idx += 1
|
| 105 |
-
self.step_count += 1
|
| 106 |
|
| 107 |
retrieval_cost = 0.0
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
cost, done_ids = self._trigger_retrieval()
|
| 111 |
retrieval_cost = cost
|
| 112 |
-
retrievals_done = done_ids
|
| 113 |
|
| 114 |
reward = placement_reward - retrieval_cost
|
| 115 |
self.total_reward += reward
|
| 116 |
self.done = (self.current_idx >= len(self.manifest))
|
| 117 |
-
|
| 118 |
-
return self._observe(reward), reward, self.done, {
|
| 119 |
-
"rehandles": self.rehandle_count,
|
| 120 |
-
"step": self.step_count,
|
| 121 |
-
"placement_reward": round(placement_reward, 4),
|
| 122 |
-
"retrieval_cost": round(retrieval_cost, 4),
|
| 123 |
-
"retrievals_done": retrievals_done,
|
| 124 |
-
}
|
| 125 |
|
| 126 |
def _placement_reward(self, stack_index: int, container: Container) -> float:
|
| 127 |
# stack_depth = zero-based index of the just-placed container
|
|
@@ -143,7 +171,7 @@ class ContainerYardEnv:
|
|
| 143 |
|
| 144 |
return round(base, 4)
|
| 145 |
|
| 146 |
-
def _trigger_retrieval(self) ->
|
| 147 |
total_cost = 0.0
|
| 148 |
done_ids = []
|
| 149 |
for _ in range(2):
|
|
@@ -166,12 +194,16 @@ class ContainerYardEnv:
|
|
| 166 |
return round(rehandles * 0.4, 4)
|
| 167 |
return 0.0 # container not yet in yard — no penalty
|
| 168 |
|
| 169 |
-
def _get_upcoming_retrievals(self) ->
|
| 170 |
start = self.retrieval_pointer
|
| 171 |
end = min(start + self.lookahead, len(self.retrieval_queue))
|
| 172 |
return self.retrieval_queue[start:end]
|
| 173 |
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
stack_states = []
|
| 176 |
for s in self.stacks:
|
| 177 |
stack_states.append([{"id": c.id, "priority": c.priority} for c in s])
|
|
@@ -181,25 +213,20 @@ class ContainerYardEnv:
|
|
| 181 |
c = self.manifest[self.current_idx]
|
| 182 |
current = {"id": c.id, "priority": c.priority, "weight": c.weight}
|
| 183 |
|
| 184 |
-
return
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
def get_state(self) -> Dict[str, Any]:
|
| 199 |
-
obs = self._observe()
|
| 200 |
-
obs["score"] = self.score()
|
| 201 |
-
obs["total_reward"] = round(self.total_reward, 4)
|
| 202 |
-
return obs
|
| 203 |
|
| 204 |
def score(self) -> float:
|
| 205 |
"""Normalized score in [0.0, 1.0]. Based on actual retrievals attempted."""
|
|
@@ -209,3 +236,19 @@ class ContainerYardEnv:
|
|
| 209 |
return 1.0
|
| 210 |
score = max(0.0, 1.0 - self.rehandle_count / worst_case)
|
| 211 |
return round(min(score, 1.0), 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import random
|
| 4 |
+
import uuid
|
| 5 |
from dataclasses import dataclass
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from openenv.core.env_server import Environment, State
|
| 9 |
+
from openenv.core.env_server.types import EnvironmentMetadata
|
| 10 |
+
|
| 11 |
+
from models import ContainerAction, ContainerObservation
|
| 12 |
|
| 13 |
+
|
| 14 |
+
@dataclass(slots=True)
|
| 15 |
class Container:
|
| 16 |
id: str
|
| 17 |
+
priority: int
|
| 18 |
weight: float
|
| 19 |
|
| 20 |
DIFFICULTY_CONFIG = {
|
|
|
|
| 44 |
},
|
| 45 |
}
|
| 46 |
|
| 47 |
+
class ContainerYardEnvironment(Environment):
|
| 48 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 49 |
+
|
| 50 |
+
def __init__(self) -> None:
    """Create a session and start a fresh medium-difficulty episode.

    ``_init_env`` unconditionally assigns ``self._difficulty``, so the
    previous separate ``self._difficulty = "medium"`` default here was
    redundant and has been dropped.
    """
    # OpenEnv session bookkeeping: fresh episode id, zero steps taken.
    self._state = State(episode_id=str(uuid.uuid4()), step_count=0)
    self._init_env("medium", seed=None)
|
| 54 |
+
|
| 55 |
+
def _init_env(self, difficulty: str, seed: int | None) -> None:
|
| 56 |
+
if difficulty not in DIFFICULTY_CONFIG:
|
| 57 |
+
difficulty = "medium"
|
| 58 |
+
self._difficulty = difficulty
|
| 59 |
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 60 |
self.n_stacks = cfg["n_stacks"]
|
| 61 |
self.max_height = cfg["max_height"]
|
|
|
|
| 63 |
self.retrieval_interval = cfg["retrieval_interval"]
|
| 64 |
self.lookahead = cfg["lookahead"]
|
| 65 |
self.priority_weights = cfg["priority_weights"]
|
| 66 |
+
if seed is not None:
|
| 67 |
+
random.seed(seed)
|
| 68 |
+
self.stacks: list[list[Container]] = [[] for _ in range(self.n_stacks)]
|
|
|
|
|
|
|
|
|
|
| 69 |
self.rehandle_count = 0
|
|
|
|
| 70 |
self.total_reward = 0.0
|
| 71 |
self.done = False
|
| 72 |
+
self.manifest: list[Container] = self._generate_manifest()
|
| 73 |
+
self.retrieval_queue: list[str] = self._generate_retrieval_queue()
|
| 74 |
self.retrieval_pointer = 0
|
| 75 |
self.current_idx = 0
|
|
|
|
| 76 |
|
| 77 |
+
def _generate_manifest(self) -> list[Container]:
|
| 78 |
containers = []
|
| 79 |
for i in range(self.n_containers):
|
| 80 |
priority = random.choices([1, 2, 3], weights=self.priority_weights)[0]
|
|
|
|
| 85 |
))
|
| 86 |
return containers
|
| 87 |
|
| 88 |
+
def _generate_retrieval_queue(self) -> list[str]:
|
| 89 |
ids_by_priority = {1: [], 2: [], 3: []}
|
| 90 |
for c in self.manifest:
|
| 91 |
ids_by_priority[c.priority].append(c.id)
|
|
|
|
| 94 |
queue = ids_by_priority[1] + ids_by_priority[2] + ids_by_priority[3]
|
| 95 |
return queue
|
| 96 |
|
| 97 |
+
def reset(
    self,
    seed: int | None = None,
    episode_id: str | None = None,
    **kwargs: Any,
) -> ContainerObservation:
    """Begin a new episode and return the initial observation.

    An optional ``difficulty`` ("easy"/"medium"/"hard") may be supplied via
    kwargs; it defaults to "medium". A missing/empty ``episode_id`` gets a
    fresh UUID.
    """
    chosen_difficulty = kwargs.get("difficulty", "medium")
    new_episode_id = episode_id or str(uuid.uuid4())
    self._state = State(episode_id=new_episode_id, step_count=0)
    self._init_env(chosen_difficulty, seed)
    return self._observe(last_reward=0.0)
|
| 110 |
+
|
| 111 |
+
def step(
    self,
    action: ContainerAction | int,
    timeout_s: float | None = None,
    **kwargs: Any,
) -> ContainerObservation:
    """Place the current container on the requested stack.

    Accepts either a ContainerAction or a bare stack index. An invalid
    move (out-of-range index or full stack) incurs a flat -2.0 penalty
    and consumes a step without placing anything. Every
    ``retrieval_interval`` steps a retrieval event is triggered, whose
    rehandle cost is subtracted from the placement reward.
    """
    if self.done:
        # Episode already over: no-op observation with zero reward.
        return self._observe(0.0)

    stack_index = action if isinstance(action, int) else action.stack_index

    # Short-circuit ordering keeps the len() lookup safe: the bounds
    # checks run first, so an out-of-range index never reaches indexing.
    bad_move = (
        stack_index < 0
        or stack_index >= self.n_stacks
        or len(self.stacks[stack_index]) >= self.max_height
    )
    if bad_move:
        penalty = -2.0
        self.total_reward += penalty
        self._state.step_count += 1
        return self._observe(penalty)

    placed = self.manifest[self.current_idx]
    self.stacks[stack_index].append(placed)
    placement_reward = self._placement_reward(stack_index, placed)

    self.current_idx += 1
    self._state.step_count += 1

    retrieval_cost = 0.0
    if self._state.step_count % self.retrieval_interval == 0:
        retrieval_cost, _ = self._trigger_retrieval()

    reward = placement_reward - retrieval_cost
    self.total_reward += reward
    self.done = self.current_idx >= len(self.manifest)
    return self._observe(reward)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
def _placement_reward(self, stack_index: int, container: Container) -> float:
|
| 155 |
# stack_depth = zero-based index of the just-placed container
|
|
|
|
| 171 |
|
| 172 |
return round(base, 4)
|
| 173 |
|
| 174 |
+
def _trigger_retrieval(self) -> tuple[float, list[str]]:
|
| 175 |
total_cost = 0.0
|
| 176 |
done_ids = []
|
| 177 |
for _ in range(2):
|
|
|
|
| 194 |
return round(rehandles * 0.4, 4)
|
| 195 |
return 0.0 # container not yet in yard — no penalty
|
| 196 |
|
| 197 |
+
def _get_upcoming_retrievals(self) -> list[str]:
    """Container IDs of the next ``lookahead`` scheduled retrievals.

    Python slicing clamps past the end of the queue, so no explicit
    ``min`` against the queue length is required.
    """
    lo = self.retrieval_pointer
    return self.retrieval_queue[lo : lo + self.lookahead]
|
| 201 |
|
| 202 |
+
@property
def state(self) -> State:
    """Current OpenEnv session state (episode id and step counter)."""
    return self._state
|
| 205 |
+
|
| 206 |
+
def _observe(self, last_reward: float = 0.0) -> ContainerObservation:
|
| 207 |
stack_states = []
|
| 208 |
for s in self.stacks:
|
| 209 |
stack_states.append([{"id": c.id, "priority": c.priority} for c in s])
|
|
|
|
| 213 |
c = self.manifest[self.current_idx]
|
| 214 |
current = {"id": c.id, "priority": c.priority, "weight": c.weight}
|
| 215 |
|
| 216 |
+
return ContainerObservation(
|
| 217 |
+
stack_states=stack_states,
|
| 218 |
+
current_container=current,
|
| 219 |
+
upcoming_retrievals=self._get_upcoming_retrievals(),
|
| 220 |
+
rehandle_count=self.rehandle_count,
|
| 221 |
+
step=self._state.step_count,
|
| 222 |
+
containers_remaining=len(self.manifest) - self.current_idx,
|
| 223 |
+
n_stacks=self.n_stacks,
|
| 224 |
+
max_height=self.max_height,
|
| 225 |
+
difficulty=self._difficulty,
|
| 226 |
+
last_reward=last_reward,
|
| 227 |
+
score=self.score(),
|
| 228 |
+
done=self.done,
|
| 229 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def score(self) -> float:
|
| 232 |
"""Normalized score in [0.0, 1.0]. Based on actual retrievals attempted."""
|
|
|
|
| 236 |
return 1.0
|
| 237 |
score = max(0.0, 1.0 - self.rehandle_count / worst_case)
|
| 238 |
return round(min(score, 1.0), 4)
|
| 239 |
+
|
| 240 |
+
def get_state(self) -> dict[str, Any]:
    """Serialize the current observation to a plain dict for the API."""
    observation = self._observe()
    return observation.model_dump()
|
| 242 |
+
|
| 243 |
+
def get_metadata(self) -> EnvironmentMetadata:
    """Static metadata describing this environment to OpenEnv clients."""
    summary = (
        "Container terminal yard environment where agents place incoming "
        "containers into stacks to minimize rehandle cost during retrieval."
    )
    return EnvironmentMetadata(
        name="container-port-env",
        description=summary,
        version="0.1.0",
    )
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
ContainerYardEnv = ContainerYardEnvironment
|
server/models.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
from pydantic import BaseModel, Field
|
| 2 |
-
from typing import List, Optional, Dict, Any
|
| 3 |
-
|
| 4 |
-
class ContainerInfo(BaseModel):
|
| 5 |
-
id: str
|
| 6 |
-
priority: int = Field(..., ge=1, le=3)
|
| 7 |
-
weight: float
|
| 8 |
-
|
| 9 |
-
class StackEntry(BaseModel):
|
| 10 |
-
id: str
|
| 11 |
-
priority: int
|
| 12 |
-
|
| 13 |
-
class ContainerAction(BaseModel):
|
| 14 |
-
stack_index: int = Field(..., description="Which stack (0-indexed) to place the current container into")
|
| 15 |
-
|
| 16 |
-
class ContainerObservation(BaseModel):
|
| 17 |
-
stack_states: List[List[Dict[str, Any]]]
|
| 18 |
-
current_container: Optional[Dict[str, Any]]
|
| 19 |
-
upcoming_retrievals: List[str]
|
| 20 |
-
rehandle_count: int
|
| 21 |
-
step: int
|
| 22 |
-
containers_remaining: int
|
| 23 |
-
n_stacks: int
|
| 24 |
-
max_height: int
|
| 25 |
-
difficulty: str
|
| 26 |
-
last_reward: float
|
| 27 |
-
done: bool
|
| 28 |
-
|
| 29 |
-
class ContainerState(BaseModel):
|
| 30 |
-
stack_states: List[List[Dict[str, Any]]]
|
| 31 |
-
rehandle_count: int
|
| 32 |
-
step: int
|
| 33 |
-
score: float
|
| 34 |
-
difficulty: str
|
| 35 |
-
done: bool
|
| 36 |
-
total_reward: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/server.py
DELETED
|
@@ -1,84 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import uuid
|
| 3 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 4 |
-
from server.environment import ContainerYardEnv
|
| 5 |
-
from server.models import ContainerAction
|
| 6 |
-
|
| 7 |
-
app = FastAPI(title="Container Port OpenEnv", version="0.1.0")
|
| 8 |
-
|
| 9 |
-
sessions: dict = {}
|
| 10 |
-
|
| 11 |
-
@app.get("/ping")
|
| 12 |
-
def ping():
|
| 13 |
-
return {"status": "ok", "env": "container-port-env"}
|
| 14 |
-
|
| 15 |
-
@app.get("/health")
|
| 16 |
-
def health():
|
| 17 |
-
return {
|
| 18 |
-
"status": "healthy",
|
| 19 |
-
"active_sessions": len(sessions),
|
| 20 |
-
"difficulties": ["easy", "medium", "hard"],
|
| 21 |
-
}
|
| 22 |
-
|
| 23 |
-
@app.websocket("/ws")
|
| 24 |
-
async def websocket_endpoint(websocket: WebSocket):
|
| 25 |
-
await websocket.accept()
|
| 26 |
-
session_id = str(uuid.uuid4())
|
| 27 |
-
sessions[session_id] = ContainerYardEnv(difficulty="medium")
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
while True:
|
| 31 |
-
raw = await websocket.receive_text()
|
| 32 |
-
msg = json.loads(raw)
|
| 33 |
-
msg_type = msg.get("type")
|
| 34 |
-
env = sessions[session_id]
|
| 35 |
-
|
| 36 |
-
if msg_type == "reset":
|
| 37 |
-
difficulty = msg.get("difficulty", "medium")
|
| 38 |
-
if difficulty not in ["easy", "medium", "hard"]:
|
| 39 |
-
difficulty = "medium"
|
| 40 |
-
sessions[session_id] = ContainerYardEnv(difficulty=difficulty)
|
| 41 |
-
env = sessions[session_id]
|
| 42 |
-
obs = env.reset()
|
| 43 |
-
await websocket.send_text(json.dumps({
|
| 44 |
-
"type": "reset",
|
| 45 |
-
"observation": obs,
|
| 46 |
-
"reward": 0.0,
|
| 47 |
-
"done": False,
|
| 48 |
-
"session_id": session_id,
|
| 49 |
-
}))
|
| 50 |
-
|
| 51 |
-
elif msg_type == "step":
|
| 52 |
-
try:
|
| 53 |
-
action = ContainerAction(**msg["action"])
|
| 54 |
-
obs, reward, done, info = env.step(action.stack_index)
|
| 55 |
-
await websocket.send_text(json.dumps({
|
| 56 |
-
"type": "step",
|
| 57 |
-
"observation": obs,
|
| 58 |
-
"reward": reward,
|
| 59 |
-
"done": done,
|
| 60 |
-
"info": info,
|
| 61 |
-
}))
|
| 62 |
-
except Exception as e:
|
| 63 |
-
await websocket.send_text(json.dumps({
|
| 64 |
-
"type": "error",
|
| 65 |
-
"message": str(e),
|
| 66 |
-
}))
|
| 67 |
-
|
| 68 |
-
elif msg_type == "state":
|
| 69 |
-
state = env.get_state()
|
| 70 |
-
await websocket.send_text(json.dumps({
|
| 71 |
-
"type": "state",
|
| 72 |
-
"state": state,
|
| 73 |
-
}))
|
| 74 |
-
|
| 75 |
-
else:
|
| 76 |
-
await websocket.send_text(json.dumps({
|
| 77 |
-
"type": "error",
|
| 78 |
-
"message": f"Unknown message type: {msg_type}",
|
| 79 |
-
}))
|
| 80 |
-
|
| 81 |
-
except WebSocketDisconnect:
|
| 82 |
-
pass
|
| 83 |
-
finally:
|
| 84 |
-
sessions.pop(session_id, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared pytest setup: make the project root importable in tests."""

import sys
from pathlib import Path

# Keep test runs from littering the working tree with .pyc files.
sys.dont_write_bytecode = True

# tests/ lives one level below the project root.
ROOT = Path(__file__).resolve().parents[1]

if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
tests/test_env.py
DELETED
|
@@ -1,110 +0,0 @@
|
|
| 1 |
-
import pytest
|
| 2 |
-
from server.environment import ContainerYardEnv, DIFFICULTY_CONFIG
|
| 3 |
-
|
| 4 |
-
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 5 |
-
def test_reset_returns_valid_obs(difficulty):
|
| 6 |
-
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 7 |
-
obs = env.reset()
|
| 8 |
-
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 9 |
-
assert len(obs["stack_states"]) == cfg["n_stacks"]
|
| 10 |
-
assert obs["current_container"] is not None
|
| 11 |
-
assert obs["step"] == 0
|
| 12 |
-
assert obs["rehandle_count"] == 0
|
| 13 |
-
assert obs["difficulty"] == difficulty
|
| 14 |
-
assert obs["done"] == False
|
| 15 |
-
|
| 16 |
-
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 17 |
-
def test_step_valid_action(difficulty):
|
| 18 |
-
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 19 |
-
env.reset()
|
| 20 |
-
obs, reward, done, info = env.step(0)
|
| 21 |
-
assert isinstance(reward, float)
|
| 22 |
-
assert obs["step"] == 1
|
| 23 |
-
assert len(obs["stack_states"][0]) == 1
|
| 24 |
-
assert "rehandles" in info
|
| 25 |
-
|
| 26 |
-
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 27 |
-
def test_step_invalid_stack_index(difficulty):
|
| 28 |
-
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 29 |
-
env.reset()
|
| 30 |
-
obs, reward, done, info = env.step(999)
|
| 31 |
-
assert reward == -2.0
|
| 32 |
-
assert "error" in info
|
| 33 |
-
assert done == False
|
| 34 |
-
|
| 35 |
-
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 36 |
-
def test_full_episode_completes(difficulty):
|
| 37 |
-
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 38 |
-
env.reset()
|
| 39 |
-
done = False
|
| 40 |
-
steps = 0
|
| 41 |
-
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 42 |
-
n_stacks = cfg["n_stacks"]
|
| 43 |
-
max_height = cfg["max_height"]
|
| 44 |
-
while not done:
|
| 45 |
-
stacks = env._observe()["stack_states"]
|
| 46 |
-
chosen = 0
|
| 47 |
-
for i in range(n_stacks):
|
| 48 |
-
if len(stacks[i]) < max_height:
|
| 49 |
-
chosen = i
|
| 50 |
-
break
|
| 51 |
-
_, _, done, _ = env.step(chosen)
|
| 52 |
-
steps += 1
|
| 53 |
-
assert steps < 1000, "Episode did not complete in time"
|
| 54 |
-
assert done
|
| 55 |
-
|
| 56 |
-
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
|
| 57 |
-
def test_score_in_range(difficulty):
|
| 58 |
-
env = ContainerYardEnv(difficulty=difficulty, seed=42)
|
| 59 |
-
env.reset()
|
| 60 |
-
done = False
|
| 61 |
-
cfg = DIFFICULTY_CONFIG[difficulty]
|
| 62 |
-
n_stacks = cfg["n_stacks"]
|
| 63 |
-
max_height = cfg["max_height"]
|
| 64 |
-
while not done:
|
| 65 |
-
stacks = env._observe()["stack_states"]
|
| 66 |
-
chosen = 0
|
| 67 |
-
for i in range(n_stacks):
|
| 68 |
-
if len(stacks[i]) < max_height:
|
| 69 |
-
chosen = i
|
| 70 |
-
break
|
| 71 |
-
_, _, done, _ = env.step(chosen)
|
| 72 |
-
score = env.score()
|
| 73 |
-
assert 0.0 <= score <= 1.0
|
| 74 |
-
|
| 75 |
-
def test_lookahead_visibility():
|
| 76 |
-
easy_env = ContainerYardEnv(difficulty="easy", seed=42)
|
| 77 |
-
hard_env = ContainerYardEnv(difficulty="hard", seed=42)
|
| 78 |
-
easy_obs = easy_env.reset()
|
| 79 |
-
hard_obs = hard_env.reset()
|
| 80 |
-
assert len(easy_obs["upcoming_retrievals"]) > len(hard_obs["upcoming_retrievals"])
|
| 81 |
-
assert len(hard_obs["upcoming_retrievals"]) == 0
|
| 82 |
-
|
| 83 |
-
def test_reward_is_dense():
|
| 84 |
-
env = ContainerYardEnv(difficulty="medium", seed=42)
|
| 85 |
-
env.reset()
|
| 86 |
-
rewards = []
|
| 87 |
-
done = False
|
| 88 |
-
step = 0
|
| 89 |
-
while not done and step < 20:
|
| 90 |
-
stacks = env._observe()["stack_states"]
|
| 91 |
-
chosen = step % 8
|
| 92 |
-
if len(stacks[chosen]) >= 5:
|
| 93 |
-
chosen = 0
|
| 94 |
-
_, r, done, _ = env.step(chosen)
|
| 95 |
-
rewards.append(r)
|
| 96 |
-
step += 1
|
| 97 |
-
nonzero = sum(1 for r in rewards if abs(r) > 1e-6)
|
| 98 |
-
assert nonzero >= len(rewards) * 0.5, f"Too many zero rewards: {rewards}"
|
| 99 |
-
|
| 100 |
-
def test_no_double_retrieval():
|
| 101 |
-
"""Retrieval pointer advances correctly — no container retrieved twice."""
|
| 102 |
-
env = ContainerYardEnv(difficulty="easy", seed=42)
|
| 103 |
-
env.reset()
|
| 104 |
-
seen_ids = set()
|
| 105 |
-
for _ in range(env.n_containers):
|
| 106 |
-
if env.done:
|
| 107 |
-
break
|
| 108 |
-
env.step(0 if len(env.stacks[0]) < env.max_height else 1)
|
| 109 |
-
# retrieval_pointer should be <= queue length
|
| 110 |
-
assert env.retrieval_pointer <= len(env.retrieval_queue)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_openenv_env.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit and HTTP integration tests for the container-port OpenEnv env.

Fixes over the original: ``test_score_in_range`` now bounds its episode
loop (previously it could hang forever on a regression where ``done``
never flips), and the repeated "first stack with room" chooser is
factored into a single helper.
"""

import pytest
from fastapi.testclient import TestClient

from models import ContainerAction
from server.app import app
from server.environment import ContainerYardEnvironment, DIFFICULTY_CONFIG


def as_dict(observation):
    """Normalize an observation (pydantic model or plain dict) to a dict."""
    return observation.model_dump() if hasattr(observation, "model_dump") else observation


def first_open_stack(stacks, max_height):
    """Index of the first stack with room below ``max_height`` (0 fallback)."""
    return next((i for i, s in enumerate(stacks) if len(s) < max_height), 0)


# Unit tests: pure environment logic (no HTTP)

@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
def test_reset_returns_valid_obs(difficulty):
    env = ContainerYardEnvironment()
    obs = as_dict(env.reset(difficulty=difficulty, seed=42))
    cfg = DIFFICULTY_CONFIG[difficulty]
    assert len(obs["stack_states"]) == cfg["n_stacks"]
    assert obs["current_container"] is not None
    assert obs["step"] == 0
    assert obs["rehandle_count"] == 0
    assert obs["difficulty"] == difficulty
    assert obs["done"] is False


@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
def test_step_valid_action(difficulty):
    env = ContainerYardEnvironment()
    env.reset(difficulty=difficulty, seed=42)
    obs = as_dict(env.step(ContainerAction(stack_index=0)))
    assert obs["step"] == 1
    assert len(obs["stack_states"][0]) == 1
    assert isinstance(obs["last_reward"], float)


@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
def test_step_invalid_action_penalized(difficulty):
    env = ContainerYardEnvironment()
    env.reset(difficulty=difficulty, seed=42)
    obs = as_dict(env.step(ContainerAction(stack_index=999)))
    assert obs["last_reward"] == -2.0


def test_score_in_range():
    env = ContainerYardEnvironment()
    env.reset(difficulty="medium", seed=42)
    done = False
    steps = 0
    while not done:
        stacks = as_dict(env._observe())["stack_states"]
        chosen = first_open_stack(stacks, env.max_height)
        obs = as_dict(env.step(ContainerAction(stack_index=chosen)))
        done = obs["done"]
        steps += 1
        # Guard against an infinite loop if the episode never terminates.
        assert steps < 500, "Episode did not terminate"
    assert 0.0 <= env.score() <= 1.0


@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
def test_full_episode_completes(difficulty):
    env = ContainerYardEnvironment()
    env.reset(difficulty=difficulty, seed=42)
    cfg = DIFFICULTY_CONFIG[difficulty]
    done = False
    steps = 0
    while not done:
        stacks = as_dict(env._observe())["stack_states"]
        chosen = first_open_stack(stacks, cfg["max_height"])
        obs = as_dict(env.step(ContainerAction(stack_index=chosen)))
        done = obs["done"]
        steps += 1
        assert steps < 500, "Episode did not complete"
    assert done is True


def test_lookahead_visibility():
    easy_env = ContainerYardEnvironment()
    hard_env = ContainerYardEnvironment()
    easy_obs = as_dict(easy_env.reset(difficulty="easy", seed=42))
    hard_obs = as_dict(hard_env.reset(difficulty="hard", seed=42))
    assert len(easy_obs["upcoming_retrievals"]) > len(hard_obs["upcoming_retrievals"])
    assert len(hard_obs["upcoming_retrievals"]) == 0


def test_reward_is_dense():
    env = ContainerYardEnvironment()
    env.reset(difficulty="medium", seed=42)
    rewards = []
    done = False
    step = 0
    while not done and step < 20:
        stacks = as_dict(env._observe())["stack_states"]
        chosen = step % env.n_stacks
        if len(stacks[chosen]) >= env.max_height:
            chosen = 0
        obs = as_dict(env.step(ContainerAction(stack_index=chosen)))
        rewards.append(obs["last_reward"])
        done = obs["done"]
        step += 1
    nonzero = sum(1 for r in rewards if abs(r) > 1e-6)
    assert nonzero >= len(rewards) * 0.5, f"Too many zero rewards: {rewards}"


def test_no_double_retrieval():
    env = ContainerYardEnvironment()
    env.reset(difficulty="easy", seed=42)
    for _ in range(env.n_containers):
        if env.done:
            break
        chosen = first_open_stack(env.stacks, env.max_height)
        env.step(ContainerAction(stack_index=chosen))
    assert env.retrieval_pointer <= len(env.retrieval_queue)


# HTTP integration tests

def test_health_route():
    client = TestClient(app)
    resp = client.get("/health")
    assert resp.status_code == 200


def test_web_ui_route():
    client = TestClient(app, follow_redirects=True)
    resp = client.get("/web")
    assert resp.status_code == 200


def test_http_reset_returns_observation():
    client = TestClient(app)
    resp = client.post("/reset", json={"difficulty": "easy"})
    assert resp.status_code == 200
    body = resp.json()
    obs = body.get("observation", body)
    assert obs.get("difficulty") == "easy"
    assert obs.get("step") == 0
    assert obs.get("containers_remaining") == DIFFICULTY_CONFIG["easy"]["n_containers"]


def test_http_reset_then_step_preserves_state():
    client = TestClient(app)

    reset_resp = client.post("/web/reset", json={"difficulty": "easy"})
    assert reset_resp.status_code == 200
    reset_body = reset_resp.json()

    session_id = reset_body.get("session_id") or reset_body.get("id")
    obs_after_reset = reset_body.get("observation", reset_body)
    assert obs_after_reset.get("step") == 0
    n_containers = DIFFICULTY_CONFIG["easy"]["n_containers"]
    assert obs_after_reset.get("containers_remaining") == n_containers

    step_payload = {"action": {"stack_index": 0}}
    if session_id:
        step_payload["session_id"] = session_id

    step_resp = client.post("/web/step", json=step_payload)
    assert step_resp.status_code == 200
    step_body = step_resp.json()
    obs_after_step = step_body.get("observation", step_body)

    assert obs_after_step.get("step") == 1
    assert obs_after_step.get("containers_remaining") == n_containers - 1
    assert len(obs_after_step["stack_states"][0]) == 1
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|