Shabista Sehar committed on
Commit 11f9523 · 0 Parent(s):

Initial: Container Port OpenEnv
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ ENV PYTHONPATH=/app
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "server.server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,71 @@
+ # Container Port Environment
+
+ An OpenEnv-compatible RL environment for container yard management at a shipping terminal.
+
+ ## Task
+
+ A ship arrives with N containers (priority 1=urgent, 2=normal, 3=low). The agent places each
+ container into a stack. At regular intervals, specific containers are retrieved. If a target is
+ buried under others, each container above it is a **rehandle** — expensive in real port operations.
+
+ **Goal: minimize total rehandle operations across the episode.**
+
+ ## Difficulty Levels
+
+ | Parameter          | Easy     | Medium   | Hard     |
+ |--------------------|----------|----------|----------|
+ | Stacks             | 6        | 8        | 10       |
+ | Max stack height   | 4        | 5        | 6        |
+ | Containers         | 20       | 35       | 50       |
+ | Retrieval interval | every 5  | every 5  | every 4  |
+ | Lookahead shown    | 5        | 3        | 0        |
+
+ ## Reward
+
+ | Event | Reward |
+ |---|---|
+ | Accessible placement of priority-1 (near top) | up to +0.45 |
+ | General placement | +0.03 to +0.30 |
+ | Burying high-priority under low-priority | -0.10 to -0.20 |
+ | Invalid action (full stack / bad index) | -2.0 |
+ | Each rehandle at retrieval time | -0.40 |
+
+ ## Score
+
+ `score = 1.0 - (actual_rehandles / worst_case_rehandles)`, clipped to [0.0, 1.0].
+
+ ## Setup
+ ```bash
+ pip install -r requirements.txt
+ uvicorn server.server:app --host 0.0.0.0 --port 7860
+ ```
+
+ ## Run inference
+ ```bash
+ # Greedy agent, all difficulties
+ python inference.py --difficulty all
+
+ # LLM agent (requires HF token in env)
+ export HF_TOKEN=hf_your_token_here
+ python inference.py --use-llm --difficulty all
+
+ # Against deployed HF Space
+ python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space --difficulty all
+ ```
+
+ ## Docker
+ ```bash
+ docker build -t container-port-env .
+ docker run -p 7860:7860 container-port-env
+ ```
+
+ ## API
+
+ - `GET /ping` — health check
+ - `GET /health` — server stats
+ - `WS /ws` — WebSocket interface
+
+ WebSocket messages:
+ - `{"type": "reset", "difficulty": "easy"}` — start episode
+ - `{"type": "step", "action": {"stack_index": 2}}` — place container
+ - `{"type": "state"}` — get full state with score
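The score formula above can be sanity-checked offline with a small sketch (the numbers are hypothetical; `worst_case_rehandles` assumes each retrieved container is buried under a full stack, i.e. `max_height - 1` rehandles per retrieval, matching `ContainerYardEnv.score()`):

```python
def score(rehandles: int, n_retrieved: int, max_height: int) -> float:
    # Worst case: every retrieved container sits at the bottom of a full stack,
    # costing (max_height - 1) rehandles per retrieval.
    worst_case = n_retrieved * (max_height - 1)
    if worst_case == 0:
        return 1.0  # nothing retrieved yet -> perfect by convention
    return round(max(0.0, 1.0 - rehandles / worst_case), 4)

print(score(6, 10, 5))   # 10 retrievals, 6 rehandles, height-5 yard -> 0.85
print(score(0, 10, 5))   # perfect play -> 1.0
```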
client/__init__.py ADDED
@@ -0,0 +1 @@
+
client/container_env.py ADDED
@@ -0,0 +1,37 @@
+ import json
+ import websockets
+ from typing import Any, Dict, Tuple
+
+ class ContainerEnvClient:
+     """Async client for Container Port OpenEnv."""
+
+     def __init__(self, base_url: str = "http://localhost:7860"):
+         ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
+         self.ws_url = ws_url.rstrip("/") + "/ws"
+         self._ws = None
+
+     async def __aenter__(self):
+         self._ws = await websockets.connect(self.ws_url)
+         return self
+
+     async def __aexit__(self, *args):
+         if self._ws:
+             await self._ws.close()
+
+     async def reset(self, difficulty: str = "medium") -> Dict[str, Any]:
+         await self._ws.send(json.dumps({"type": "reset", "difficulty": difficulty}))
+         resp = json.loads(await self._ws.recv())
+         return resp["observation"]
+
+     async def step(self, stack_index: int) -> Tuple[Dict, float, bool, Dict]:
+         await self._ws.send(json.dumps({
+             "type": "step",
+             "action": {"stack_index": stack_index}
+         }))
+         resp = json.loads(await self._ws.recv())
+         return resp["observation"], resp["reward"], resp["done"], resp.get("info", {})
+
+     async def state(self) -> Dict[str, Any]:
+         await self._ws.send(json.dumps({"type": "state"}))
+         resp = json.loads(await self._ws.recv())
+         return resp["state"]
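The http→ws URL rewrite in `__init__` above can be exercised in isolation (a standalone copy of that logic, not an import of the client):

```python
# Standalone copy of the URL rewrite from ContainerEnvClient.__init__.
def to_ws_url(base_url: str) -> str:
    ws = base_url.replace("http://", "ws://").replace("https://", "wss://")
    return ws.rstrip("/") + "/ws"   # normalize trailing slash, append endpoint

print(to_ws_url("http://localhost:7860"))                      # ws://localhost:7860/ws
print(to_ws_url("https://user-container-port-env.hf.space/"))  # wss://user-container-port-env.hf.space/ws
```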
inference.py ADDED
@@ -0,0 +1,224 @@
+ #!/usr/bin/env python3
+ """
+ Container Port OpenEnv — Baseline Inference Script
+ SST x Meta PyTorch OpenEnv Hackathon
+
+ Required environment variables (or set below):
+     HF_TOKEN     - Your Hugging Face token
+     API_BASE_URL - LLM API endpoint (default: https://router.huggingface.co/v1)
+     MODEL_NAME   - Model identifier (default: meta-llama/Llama-3.1-8B-Instruct)
+
+ Usage:
+     python inference.py
+     python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space --difficulty all
+     python inference.py --difficulty easy
+ """
+
+ import os
+ import sys
+ import json
+ import asyncio
+ import argparse
+ import websockets
+ from openai import OpenAI
+
+ # Required configuration variables
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+ HF_TOKEN = os.getenv("HF_TOKEN", "")  # set your HF token here or via env var
+
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
+
+
+ def _llm_client() -> OpenAI:
+     """Return an OpenAI-compatible client pointed at HF Inference Router."""
+     return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
+
+
+ def greedy_decide(obs: dict) -> int:
+     """
+     Greedy heuristic agent — no LLM call.
+     Scores each valid stack by accessibility and priority compatibility.
+     """
+     stacks = obs["stack_states"]
+     current = obs.get("current_container")
+     max_height = obs["max_height"]
+     upcoming = set(obs.get("upcoming_retrievals", []))
+
+     if current is None:
+         return 0
+
+     cur_priority = current["priority"]
+     best_stack, best_score = -1, float("-inf")
+
+     for i, stack in enumerate(stacks):
+         depth = len(stack)
+         if depth >= max_height:
+             continue
+
+         score = 0.0
+         accessibility = (max_height - depth) / max_height
+         score += accessibility * (4 - cur_priority)
+
+         if depth > 0:
+             top_priority = stack[-1]["priority"]
+             if cur_priority > top_priority:
+                 score -= 10.0 * (cur_priority - top_priority)
+             elif cur_priority < top_priority:
+                 score += 3.0
+
+         if current["id"] in upcoming:
+             score += 5.0 * accessibility
+
+         if depth > 0:
+             score += 0.5
+
+         if score > best_score:
+             best_score = score
+             best_stack = i
+
+     if best_stack == -1:
+         for i, stack in enumerate(stacks):
+             if len(stack) < max_height:
+                 return i
+     return best_stack
+
+
+ def llm_decide(obs: dict) -> int:
+     """Use HF-hosted LLM via OpenAI-compatible client to choose a stack."""
+     stacks = obs["stack_states"]
+     current = obs.get("current_container")
+     n_stacks = obs["n_stacks"]
+     max_height = obs["max_height"]
+     upcoming = obs.get("upcoming_retrievals", [])
+     difficulty = obs.get("difficulty", "medium")
+
+     stack_lines = []
+     for i, stack in enumerate(stacks):
+         if not stack:
+             stack_lines.append(f"  Stack {i}: EMPTY (0/{max_height})")
+         else:
+             contents = ", ".join(f"{c['id']}(p{c['priority']})" for c in stack)
+             stack_lines.append(
+                 f"  Stack {i}: [{contents}] depth={len(stack)}/{max_height},"
+                 f" top=priority-{stack[-1]['priority']}"
+             )
+
+     prompt = (
+         f"You are an expert container yard planner.\n"
+         f"TASK: Place the incoming container into a stack to MINIMIZE future rehandle operations.\n"
+         f"RULE: When a container is retrieved, every container ON TOP of it must be moved (rehandle).\n"
+         f"Priority 1=URGENT (retrieved first), 2=Normal, 3=Low (retrieved last).\n\n"
+         f"DIFFICULTY: {difficulty}\n"
+         f"UPCOMING RETRIEVALS (next to be retrieved, in order): "
+         f"{upcoming if upcoming else 'Unknown (hard mode)'}\n\n"
+         f"CONTAINER TO PLACE: id={current['id']}, priority={current['priority']}, "
+         f"weight={current['weight']}kg\n\n"
+         f"STACK STATES (bottom → top):\n" + "\n".join(stack_lines) + "\n\n"
+         f"Respond with ONLY valid JSON: {{\"stack_index\": <integer 0-{n_stacks-1}>}}"
+     )
+
+     try:
+         client = _llm_client()
+         response = client.chat.completions.create(
+             model=MODEL_NAME,
+             max_tokens=64,
+             temperature=0.0,
+             messages=[{"role": "user", "content": prompt}],
+         )
+         text = response.choices[0].message.content.strip()
+         # strip markdown fences if model wraps in ```json ... ```
+         if "```" in text:
+             text = text.split("```")[1]
+             if text.startswith("json"):
+                 text = text[4:]
+         decision = json.loads(text.strip())
+         idx = int(decision["stack_index"])
+         if 0 <= idx < n_stacks and len(obs["stack_states"][idx]) < max_height:
+             return idx
+     except Exception as e:
+         print(f"  [LLM fallback: {e}]", file=sys.stderr)
+
+     return greedy_decide(obs)
+
+
+ async def run_episode(url: str, difficulty: str = "medium", use_llm: bool = False) -> float:
+     ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
+     if not ws_url.endswith("/ws"):
+         ws_url = ws_url.rstrip("/") + "/ws"
+
+     # [START] log
+     print(json.dumps({"type": "[START]", "task": difficulty, "difficulty": difficulty,
+                       "env_url": url, "model": MODEL_NAME if use_llm else "greedy"}))
+     sys.stdout.flush()
+
+     total_reward = 0.0
+     step = 0
+
+     async with websockets.connect(ws_url) as ws:
+         await ws.send(json.dumps({"type": "reset", "difficulty": difficulty}))
+         resp = json.loads(await ws.recv())
+         obs = resp["observation"]
+
+         while not obs.get("done", False):
+             action_idx = llm_decide(obs) if use_llm else greedy_decide(obs)
+
+             await ws.send(json.dumps({"type": "step", "action": {"stack_index": action_idx}}))
+             resp = json.loads(await ws.recv())
+             obs = resp["observation"]
+             reward = resp["reward"]
+             done = resp["done"]
+             total_reward += reward
+             step += 1
+
+             # [STEP] log
+             print(json.dumps({
+                 "type": "[STEP]",
+                 "step": step,
+                 "action": action_idx,
+                 "reward": round(reward, 4),
+                 "total_reward": round(total_reward, 4),
+                 "done": done,
+                 "rehandle_count": obs["rehandle_count"],
+             }))
+             sys.stdout.flush()
+
+         # fetch final state for score
+         await ws.send(json.dumps({"type": "state"}))
+         state_resp = json.loads(await ws.recv())
+         state = state_resp["state"]
+
+     final_score = state.get("score", 0.0)
+
+     # [END] log
+     print(json.dumps({
+         "type": "[END]",
+         "task": difficulty,
+         "difficulty": difficulty,
+         "total_reward": round(total_reward, 4),
+         "final_score": final_score,
+         "total_steps": step,
+         "rehandle_count": state.get("rehandle_count", 0),
+     }))
+     sys.stdout.flush()
+
+     return final_score
+
+
+ async def run_all(url: str, use_llm: bool = False):
+     for diff in ["easy", "medium", "hard"]:
+         await run_episode(url, difficulty=diff, use_llm=use_llm)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Container Port Baseline Agent")
+     parser.add_argument("--url", default=ENV_URL)
+     parser.add_argument("--difficulty", default="all", choices=["easy", "medium", "hard", "all"])
+     parser.add_argument("--use-llm", action="store_true")
+     args = parser.parse_args()
+
+     if args.difficulty == "all":
+         asyncio.run(run_all(args.url, use_llm=args.use_llm))
+     else:
+         asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=args.use_llm))
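The greedy heuristic can be exercised offline on a hand-built observation. Below is a condensed copy of `greedy_decide` (same scoring terms, reordered for brevity) run against toy data — no server or LLM involved:

```python
# Condensed copy of the greedy heuristic, applied to a toy observation.
def greedy_decide(obs: dict) -> int:
    stacks, cur = obs["stack_states"], obs["current_container"]
    max_height = obs["max_height"]
    upcoming = set(obs.get("upcoming_retrievals", []))
    best_stack, best_score = -1, float("-inf")
    for i, stack in enumerate(stacks):
        depth = len(stack)
        if depth >= max_height:
            continue  # full stack: invalid placement
        accessibility = (max_height - depth) / max_height
        score = accessibility * (4 - cur["priority"])
        if depth > 0:
            top = stack[-1]["priority"]
            if cur["priority"] > top:
                score -= 10.0 * (cur["priority"] - top)  # would bury higher priority
            elif cur["priority"] < top:
                score += 3.0
            score += 0.5  # mild preference for consolidating non-empty stacks
        if cur["id"] in upcoming:
            score += 5.0 * accessibility
        if score > best_score:
            best_score, best_stack = score, i
    return best_stack

obs = {
    "stack_states": [
        [],                                                  # stack 0: empty
        [{"id": "C001", "priority": 1}],                     # stack 1: urgent on top
        [{"id": f"C{n}", "priority": 3} for n in range(4)],  # stack 2: full
    ],
    "current_container": {"id": "C099", "priority": 3, "weight": 12.0},
    "max_height": 4,
    "upcoming_retrievals": [],
}
print(greedy_decide(obs))  # 0 — empty stack beats burying the priority-1
```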
openenv.yaml ADDED
@@ -0,0 +1,30 @@
+ name: container-port-env
+ version: "0.1.0"
+ description: >
+   Container terminal yard RL environment. An agent places incoming ship
+   containers into stacks of limited height to minimize costly rehandle
+   operations during retrieval. Features 3 difficulty levels (easy/medium/hard)
+   with different stack configurations, retrieval frequencies, and lookahead
+   visibility. Models real port logistics decision-making.
+ tags:
+   - logistics
+   - planning
+   - real-world
+   - combinatorial-optimization
+ sdk: docker
+ entry_point: server.server:app
+ tools:
+   - name: place_container
+     description: >
+       Place the current incoming container into a specified stack index.
+       Priority 1=urgent (retrieved first), 2=normal, 3=low (retrieved last).
+       Burying high-priority under low-priority causes rehandle costs.
+     input_schema:
+       type: object
+       properties:
+         stack_index:
+           type: integer
+           description: "Zero-indexed stack to place the container into"
+           minimum: 0
+       required:
+         - stack_index
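A minimal stdlib-only validator mirroring the `input_schema` above (a hypothetical helper for illustration — on the server side, validation is handled by the pydantic `ContainerAction` model):

```python
# Mirror of the tool's input_schema: required integer stack_index, minimum 0.
# The upper bound is the yard's stack count, enforced at runtime by the env.
def validate_action(payload: dict, n_stacks: int) -> bool:
    idx = payload.get("stack_index")
    # bool is a subclass of int in Python, so exclude it explicitly
    return isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < n_stacks

print(validate_action({"stack_index": 2}, 8))    # True
print(validate_action({"stack_index": -1}, 8))   # False: below minimum
print(validate_action({}, 8))                    # False: missing required key
```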
pyproject.toml ADDED
@@ -0,0 +1,17 @@
+ [build-system]
+ requires = ["setuptools>=68"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-container-port"
+ version = "0.1.0"
+ description = "Container yard RL environment for OpenEnv hackathon"
+ requires-python = ">=3.10"
+ dependencies = [
+     "fastapi>=0.110.0",
+     "uvicorn[standard]>=0.29.0",
+     "websockets>=12.0",
+     "pydantic>=2.0.0",
+     "openenv-core>=0.1.0",
+     "openai>=1.0.0",
+ ]
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi>=0.110.0
+ uvicorn[standard]>=0.29.0
+ websockets>=12.0
+ pydantic>=2.0.0
+ openenv-core>=0.1.0
+ openai>=1.0.0
+ pytest>=8.0.0
+ huggingface_hub>=0.20.0
server/__init__.py ADDED
@@ -0,0 +1 @@
+
server/__pycache__/__init__.cpython-313.pyc.1828594663216 ADDED
Binary file (145 Bytes)
server/__pycache__/__init__.cpython-313.pyc.2347090562224 ADDED
Binary file (145 Bytes)
server/__pycache__/environment.cpython-313.pyc.1828594650032 ADDED
Binary file (11.4 kB)
server/__pycache__/environment.cpython-313.pyc.2347106154928 ADDED
Binary file (11.4 kB)
server/__pycache__/models.cpython-313.pyc.2347106454640 ADDED
Binary file (2.34 kB)
server/__pycache__/server.cpython-313.pyc.2347090514912 ADDED
Binary file (3.88 kB)
server/environment.py ADDED
@@ -0,0 +1,211 @@
+ import random
+ from dataclasses import dataclass
+ from typing import List, Optional, Dict, Any, Tuple
+
+ @dataclass
+ class Container:
+     id: str
+     priority: int  # 1=urgent, 2=normal, 3=low
+     weight: float
+
+ DIFFICULTY_CONFIG = {
+     "easy": {
+         "n_stacks": 6,
+         "max_height": 4,
+         "n_containers": 20,
+         "retrieval_interval": 5,
+         "lookahead": 5,
+         "priority_weights": [0.4, 0.4, 0.2],
+     },
+     "medium": {
+         "n_stacks": 8,
+         "max_height": 5,
+         "n_containers": 35,
+         "retrieval_interval": 5,
+         "lookahead": 3,
+         "priority_weights": [0.33, 0.34, 0.33],
+     },
+     "hard": {
+         "n_stacks": 10,
+         "max_height": 6,
+         "n_containers": 50,
+         "retrieval_interval": 4,
+         "lookahead": 0,
+         "priority_weights": [0.25, 0.35, 0.40],
+     },
+ }
+
+ class ContainerYardEnv:
+     def __init__(self, difficulty: str = "medium", seed: Optional[int] = None):
+         assert difficulty in DIFFICULTY_CONFIG, f"difficulty must be one of {list(DIFFICULTY_CONFIG.keys())}"
+         self.difficulty = difficulty
+         self.seed = seed
+         cfg = DIFFICULTY_CONFIG[difficulty]
+         self.n_stacks = cfg["n_stacks"]
+         self.max_height = cfg["max_height"]
+         self.n_containers = cfg["n_containers"]
+         self.retrieval_interval = cfg["retrieval_interval"]
+         self.lookahead = cfg["lookahead"]
+         self.priority_weights = cfg["priority_weights"]
+         self.reset()
+
+     def reset(self) -> Dict[str, Any]:
+         if self.seed is not None:
+             random.seed(self.seed)
+         self.stacks: List[List[Container]] = [[] for _ in range(self.n_stacks)]
+         self.rehandle_count = 0
+         self.step_count = 0
+         self.total_reward = 0.0
+         self.done = False
+         self.manifest: List[Container] = self._generate_manifest()
+         self.retrieval_queue: List[str] = self._generate_retrieval_queue()
+         self.retrieval_pointer = 0
+         self.current_idx = 0
+         return self._observe(last_reward=0.0)
+
+     def _generate_manifest(self) -> List[Container]:
+         containers = []
+         for i in range(self.n_containers):
+             priority = random.choices([1, 2, 3], weights=self.priority_weights)[0]
+             containers.append(Container(
+                 id=f"C{i:03d}",
+                 priority=priority,
+                 weight=round(random.uniform(5.0, 30.0), 1)
+             ))
+         return containers
+
+     def _generate_retrieval_queue(self) -> List[str]:
+         ids_by_priority = {1: [], 2: [], 3: []}
+         for c in self.manifest:
+             ids_by_priority[c.priority].append(c.id)
+         for p in ids_by_priority:
+             random.shuffle(ids_by_priority[p])
+         queue = ids_by_priority[1] + ids_by_priority[2] + ids_by_priority[3]
+         return queue
+
+     def step(self, stack_index: int) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
+         if self.done:
+             return self._observe(0.0), 0.0, True, {"error": "episode already done"}
+
+         if stack_index < 0 or stack_index >= self.n_stacks:
+             reward = -2.0
+             self.total_reward += reward
+             return self._observe(reward), reward, False, {"error": f"invalid stack_index {stack_index}, must be 0-{self.n_stacks-1}"}
+
+         if len(self.stacks[stack_index]) >= self.max_height:
+             reward = -2.0
+             self.total_reward += reward
+             return self._observe(reward), reward, False, {"error": f"stack {stack_index} is full (height {self.max_height})"}
+
+         current = self.manifest[self.current_idx]
+         self.stacks[stack_index].append(current)
+         placement_reward = self._placement_reward(stack_index, current)
+
+         self.current_idx += 1
+         self.step_count += 1
+
+         retrieval_cost = 0.0
+         retrievals_done = []
+         if self.step_count % self.retrieval_interval == 0:
+             cost, done_ids = self._trigger_retrieval()
+             retrieval_cost = cost
+             retrievals_done = done_ids
+
+         reward = placement_reward - retrieval_cost
+         self.total_reward += reward
+         self.done = (self.current_idx >= len(self.manifest))
+
+         return self._observe(reward), reward, self.done, {
+             "rehandles": self.rehandle_count,
+             "step": self.step_count,
+             "placement_reward": round(placement_reward, 4),
+             "retrieval_cost": round(retrieval_cost, 4),
+             "retrievals_done": retrievals_done,
+         }
+
+     def _placement_reward(self, stack_index: int, container: Container) -> float:
+         # stack_depth = zero-based index of the just-placed container
+         stack_depth = len(self.stacks[stack_index]) - 1
+         accessibility = (self.max_height - stack_depth) / self.max_height
+         priority_weight = (4 - container.priority) / 3.0  # priority 1→1.0, 2→0.67, 3→0.33
+
+         base = 0.3 * accessibility * priority_weight
+
+         # Bonus: high-priority container placed near top (accessible for fast retrieval)
+         if container.priority == 1 and stack_depth <= 1:
+             base += 0.15
+
+         # Penalty: placing lower-priority on top of higher-priority container (causes future rehandles)
+         if stack_depth > 0:
+             below = self.stacks[stack_index][-2]  # container directly below the one just placed
+             if container.priority > below.priority:
+                 base -= 0.2 * (container.priority - below.priority) / 2.0
+
+         return round(base, 4)
+
+     def _trigger_retrieval(self) -> Tuple[float, List[str]]:
+         total_cost = 0.0
+         done_ids = []
+         for _ in range(2):
+             if self.retrieval_pointer >= len(self.retrieval_queue):
+                 break
+             target_id = self.retrieval_queue[self.retrieval_pointer]
+             self.retrieval_pointer += 1
+             cost = self._retrieve(target_id)
+             total_cost += cost
+             done_ids.append(target_id)
+         return total_cost, done_ids
+
+     def _retrieve(self, target_id: str) -> float:
+         for stack in self.stacks:
+             for i, c in enumerate(stack):
+                 if c.id == target_id:
+                     rehandles = len(stack) - 1 - i  # containers above target
+                     self.rehandle_count += rehandles
+                     stack.pop(i)
+                     return round(rehandles * 0.4, 4)
+         return 0.0  # container not yet in yard — no penalty
+
+     def _get_upcoming_retrievals(self) -> List[str]:
+         start = self.retrieval_pointer
+         end = min(start + self.lookahead, len(self.retrieval_queue))
+         return self.retrieval_queue[start:end]
+
+     def _observe(self, last_reward: float = 0.0) -> Dict[str, Any]:
+         stack_states = []
+         for s in self.stacks:
+             stack_states.append([{"id": c.id, "priority": c.priority} for c in s])
+
+         current = None
+         if self.current_idx < len(self.manifest):
+             c = self.manifest[self.current_idx]
+             current = {"id": c.id, "priority": c.priority, "weight": c.weight}
+
+         return {
+             "stack_states": stack_states,
+             "current_container": current,
+             "upcoming_retrievals": self._get_upcoming_retrievals(),
+             "rehandle_count": self.rehandle_count,
+             "step": self.step_count,
+             "containers_remaining": len(self.manifest) - self.current_idx,
+             "n_stacks": self.n_stacks,
+             "max_height": self.max_height,
+             "difficulty": self.difficulty,
+             "last_reward": last_reward,
+             "done": self.done,
+         }
+
+     def get_state(self) -> Dict[str, Any]:
+         obs = self._observe()
+         obs["score"] = self.score()
+         obs["total_reward"] = round(self.total_reward, 4)
+         return obs
+
+     def score(self) -> float:
+         """Normalized score in [0.0, 1.0]. Based on actual retrievals attempted."""
+         n_retrieved = self.retrieval_pointer  # only count retrievals that actually happened
+         worst_case = n_retrieved * (self.max_height - 1)
+         if worst_case == 0:
+             return 1.0
+         score = max(0.0, 1.0 - self.rehandle_count / worst_case)
+         return round(min(score, 1.0), 4)
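How rehandles are counted at retrieval time can be seen with a standalone sketch mirroring `_retrieve()` above (stack lists run bottom → top; the container IDs are hypothetical):

```python
# Count containers sitting above a target in one stack, as _retrieve() does.
def rehandles_for(stack: list, target_id: str) -> int:
    for i, cid in enumerate(stack):
        if cid == target_id:
            return len(stack) - 1 - i  # everything above the target must be moved
    return 0  # target not in this stack

stack = ["C007", "C002", "C011"]      # C007 at the bottom, C011 on top
print(rehandles_for(stack, "C007"))   # 2 — both containers above it
print(rehandles_for(stack, "C011"))   # 0 — already on top
```

At -0.40 per rehandle, retrieving C007 here would cost 0.8 reward, which is why the placement heuristics try to keep soon-to-be-retrieved containers near the top.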
server/models.py ADDED
@@ -0,0 +1,36 @@
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Dict, Any
+
+ class ContainerInfo(BaseModel):
+     id: str
+     priority: int = Field(..., ge=1, le=3)
+     weight: float
+
+ class StackEntry(BaseModel):
+     id: str
+     priority: int
+
+ class ContainerAction(BaseModel):
+     stack_index: int = Field(..., description="Which stack (0-indexed) to place the current container into")
+
+ class ContainerObservation(BaseModel):
+     stack_states: List[List[Dict[str, Any]]]
+     current_container: Optional[Dict[str, Any]]
+     upcoming_retrievals: List[str]
+     rehandle_count: int
+     step: int
+     containers_remaining: int
+     n_stacks: int
+     max_height: int
+     difficulty: str
+     last_reward: float
+     done: bool
+
+ class ContainerState(BaseModel):
+     stack_states: List[List[Dict[str, Any]]]
+     rehandle_count: int
+     step: int
+     score: float
+     difficulty: str
+     done: bool
+     total_reward: float
server/server.py ADDED
@@ -0,0 +1,84 @@
+ import json
+ import uuid
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+ from server.environment import ContainerYardEnv
+ from server.models import ContainerAction
+
+ app = FastAPI(title="Container Port OpenEnv", version="0.1.0")
+
+ sessions: dict = {}
+
+ @app.get("/ping")
+ def ping():
+     return {"status": "ok", "env": "container-port-env"}
+
+ @app.get("/health")
+ def health():
+     return {
+         "status": "healthy",
+         "active_sessions": len(sessions),
+         "difficulties": ["easy", "medium", "hard"],
+     }
+
+ @app.websocket("/ws")
+ async def websocket_endpoint(websocket: WebSocket):
+     await websocket.accept()
+     session_id = str(uuid.uuid4())
+     sessions[session_id] = ContainerYardEnv(difficulty="medium")
+
+     try:
+         while True:
+             raw = await websocket.receive_text()
+             msg = json.loads(raw)
+             msg_type = msg.get("type")
+             env = sessions[session_id]
+
+             if msg_type == "reset":
+                 difficulty = msg.get("difficulty", "medium")
+                 if difficulty not in ["easy", "medium", "hard"]:
+                     difficulty = "medium"
+                 sessions[session_id] = ContainerYardEnv(difficulty=difficulty)
+                 env = sessions[session_id]
+                 obs = env.reset()
+                 await websocket.send_text(json.dumps({
+                     "type": "reset",
+                     "observation": obs,
+                     "reward": 0.0,
+                     "done": False,
+                     "session_id": session_id,
+                 }))
+
+             elif msg_type == "step":
+                 try:
+                     action = ContainerAction(**msg["action"])
+                     obs, reward, done, info = env.step(action.stack_index)
+                     await websocket.send_text(json.dumps({
+                         "type": "step",
+                         "observation": obs,
+                         "reward": reward,
+                         "done": done,
+                         "info": info,
+                     }))
+                 except Exception as e:
+                     await websocket.send_text(json.dumps({
+                         "type": "error",
+                         "message": str(e),
+                     }))
+
+             elif msg_type == "state":
+                 state = env.get_state()
+                 await websocket.send_text(json.dumps({
+                     "type": "state",
+                     "state": state,
+                 }))
+
+             else:
+                 await websocket.send_text(json.dumps({
+                     "type": "error",
+                     "message": f"Unknown message type: {msg_type}",
+                 }))
+
+     except WebSocketDisconnect:
+         pass
+     finally:
+         sessions.pop(session_id, None)
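The endpoint above dispatches on plain JSON text frames; an offline sketch of the three message shapes it accepts (no server needed — this only exercises the serialization, which is what `send_text`/`receive_text` carry on the wire):

```python
import json

# The three frame types the /ws endpoint dispatches on.
frames = [
    {"type": "reset", "difficulty": "easy"},
    {"type": "step", "action": {"stack_index": 2}},
    {"type": "state"},
]
for frame in frames:
    wire = json.dumps(frame)   # what the client sends as a text frame
    msg = json.loads(wire)     # what websocket.receive_text + json.loads yields
    print(msg["type"])         # reset, step, state
```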
tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.10200 ADDED
Binary file (18.9 kB)
tests/__pycache__/test_env.cpython-313-pytest-9.0.2.pyc.16772 ADDED
Binary file (18.9 kB)
tests/test_env.py ADDED
@@ -0,0 +1,110 @@
+ import pytest
+ from server.environment import ContainerYardEnv, DIFFICULTY_CONFIG
+
+ @pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
+ def test_reset_returns_valid_obs(difficulty):
+     env = ContainerYardEnv(difficulty=difficulty, seed=42)
+     obs = env.reset()
+     cfg = DIFFICULTY_CONFIG[difficulty]
+     assert len(obs["stack_states"]) == cfg["n_stacks"]
+     assert obs["current_container"] is not None
+     assert obs["step"] == 0
+     assert obs["rehandle_count"] == 0
+     assert obs["difficulty"] == difficulty
+     assert obs["done"] is False
+
+ @pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
+ def test_step_valid_action(difficulty):
+     env = ContainerYardEnv(difficulty=difficulty, seed=42)
+     env.reset()
+     obs, reward, done, info = env.step(0)
+     assert isinstance(reward, float)
+     assert obs["step"] == 1
+     assert len(obs["stack_states"][0]) == 1
+     assert "rehandles" in info
+
+ @pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
+ def test_step_invalid_stack_index(difficulty):
+     env = ContainerYardEnv(difficulty=difficulty, seed=42)
+     env.reset()
+     obs, reward, done, info = env.step(999)
+     assert reward == -2.0
+     assert "error" in info
+     assert done is False
+
+ @pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
+ def test_full_episode_completes(difficulty):
+     env = ContainerYardEnv(difficulty=difficulty, seed=42)
+     env.reset()
+     done = False
+     steps = 0
+     cfg = DIFFICULTY_CONFIG[difficulty]
+     n_stacks = cfg["n_stacks"]
+     max_height = cfg["max_height"]
+     while not done:
+         stacks = env._observe()["stack_states"]
+         chosen = 0
+         for i in range(n_stacks):
+             if len(stacks[i]) < max_height:
+                 chosen = i
+                 break
+         _, _, done, _ = env.step(chosen)
+         steps += 1
+         assert steps < 1000, "Episode did not complete in time"
+     assert done
+
+ @pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
+ def test_score_in_range(difficulty):
+     env = ContainerYardEnv(difficulty=difficulty, seed=42)
+     env.reset()
+     done = False
+     cfg = DIFFICULTY_CONFIG[difficulty]
+     n_stacks = cfg["n_stacks"]
+     max_height = cfg["max_height"]
+     while not done:
+         stacks = env._observe()["stack_states"]
+         chosen = 0
+         for i in range(n_stacks):
+             if len(stacks[i]) < max_height:
+                 chosen = i
+                 break
+         _, _, done, _ = env.step(chosen)
+     score = env.score()
+     assert 0.0 <= score <= 1.0
+
+ def test_lookahead_visibility():
+     easy_env = ContainerYardEnv(difficulty="easy", seed=42)
+     hard_env = ContainerYardEnv(difficulty="hard", seed=42)
+     easy_obs = easy_env.reset()
+     hard_obs = hard_env.reset()
+     assert len(easy_obs["upcoming_retrievals"]) > len(hard_obs["upcoming_retrievals"])
+     assert len(hard_obs["upcoming_retrievals"]) == 0
+
+ def test_reward_is_dense():
+     env = ContainerYardEnv(difficulty="medium", seed=42)
+     env.reset()
+     rewards = []
+     done = False
+     step = 0
+     while not done and step < 20:
+         stacks = env._observe()["stack_states"]
+         chosen = step % 8        # medium config: 8 stacks
+         if len(stacks[chosen]) >= 5:  # medium config: max height 5
+             chosen = 0
+         _, r, done, _ = env.step(chosen)
+         rewards.append(r)
+         step += 1
+     nonzero = sum(1 for r in rewards if abs(r) > 1e-6)
+     assert nonzero >= len(rewards) * 0.5, f"Too many zero rewards: {rewards}"
+
+ def test_no_double_retrieval():
+     """Retrieval pointer advances correctly — no container retrieved twice."""
+     env = ContainerYardEnv(difficulty="easy", seed=42)
+     env.reset()
+     for _ in range(env.n_containers):
+         if env.done:
+             break
+         env.step(0 if len(env.stacks[0]) < env.max_height else 1)
+         # retrieval_pointer should be <= queue length
+         assert env.retrieval_pointer <= len(env.retrieval_queue)