File size: 9,504 Bytes
7a316cd
11f9523
7a316cd
0eb4f6f
11f9523
0eb4f6f
 
 
 
11f9523
 
 
 
0eb4f6f
16229c6
0eb4f6f
11f9523
 
7a316cd
0eb4f6f
7a316cd
f776f88
11f9523
 
0eb4f6f
7a316cd
11f9523
 
7a316cd
 
 
0eb4f6f
7a316cd
0eb4f6f
 
7a316cd
0eb4f6f
7a316cd
0eb4f6f
 
 
 
 
7a316cd
0eb4f6f
6f90d54
7a316cd
16229c6
 
 
6f90d54
 
 
 
 
0eb4f6f
7a316cd
 
 
 
 
0eb4f6f
11f9523
f776f88
 
 
 
 
 
 
 
 
 
 
0eb4f6f
7a316cd
11f9523
 
0eb4f6f
7a316cd
0eb4f6f
 
16229c6
0eb4f6f
 
 
 
a8ffe4c
7a316cd
0eb4f6f
a8ffe4c
0eb4f6f
 
 
 
11f9523
7a316cd
 
 
 
11f9523
 
 
 
7a316cd
 
11f9523
 
 
 
 
 
 
 
 
 
7a316cd
0eb4f6f
 
 
11f9523
 
7a316cd
11f9523
 
 
 
 
 
 
 
 
 
 
 
 
0eb4f6f
11f9523
 
0eb4f6f
7a316cd
 
 
 
 
 
11f9523
0eb4f6f
11f9523
 
7a316cd
11f9523
7a316cd
0eb4f6f
7a316cd
11f9523
 
 
 
7a316cd
 
 
 
0eb4f6f
11f9523
 
7a316cd
 
 
 
11f9523
 
 
0eb4f6f
11f9523
 
 
7a316cd
11f9523
7a316cd
 
 
 
11f9523
 
7a316cd
 
11f9523
7a316cd
 
11f9523
 
 
 
7a316cd
0eb4f6f
 
7a316cd
 
 
11f9523
6f90d54
7a316cd
0eb4f6f
7a316cd
0eb4f6f
 
 
e8cd840
0eb4f6f
 
 
 
7a316cd
0eb4f6f
7a316cd
 
0eb4f6f
 
7a316cd
0eb4f6f
 
6f90d54
0eb4f6f
7a316cd
0eb4f6f
7a316cd
 
f776f88
 
 
7a316cd
 
0eb4f6f
 
 
 
 
 
 
 
7a316cd
0eb4f6f
7a316cd
f776f88
0eb4f6f
 
 
 
7a316cd
0eb4f6f
 
a8ffe4c
 
0eb4f6f
 
 
 
 
7a316cd
11f9523
 
 
7a316cd
 
 
 
16229c6
11f9523
16229c6
11f9523
7a316cd
16229c6
11f9523
16229c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#!/usr/bin/env python3
"""
Container Port OpenEnv - Baseline Inference Script
SST x Meta PyTorch OpenEnv Hackathon 2026

Stdout format (grader parses these exactly):
  [START] task=<task> env=container-port-env model=<model>
  [STEP]  step=<n> action=<stack_idx> reward=<0.00> done=<true|false> error=<msg|null>
  [END]   success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>

Usage:
  python inference.py
  python inference.py --difficulty easy
  python inference.py --difficulty all
  python inference.py --no-llm
  python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space
"""

import argparse
import asyncio
import json
import math
import os
import sys
from typing import List, Optional

from openai import OpenAI


def _load_dotenv(path: Optional[str] = None) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` dotenv file.

    Variables that are already set in the environment always win; a key
    present in ``os.environ`` is never overwritten. Blank lines, lines
    starting with ``#`` and lines without ``=`` are skipped. Keys and values
    are whitespace-trimmed, and a value wrapped in a *matching* pair of
    single or double quotes has that pair removed; any other quote
    characters are kept as part of the value.

    Args:
        path: File to read. Defaults to ``.env`` next to this script. A
            missing path (or one that is not a regular file) is ignored.
    """
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
    if not os.path.isfile(path):
        return
    # 'utf-8-sig' transparently drops a leading BOM (common in files saved by
    # Windows editors) that would otherwise be glued onto the first key.
    with open(path, encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, _, value = line.partition('=')
            key = key.strip()
            value = value.strip()
            # Remove only a matching pair of surrounding quotes, so values that
            # legitimately end in a quote character are not truncated.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            if key and key not in os.environ:
                os.environ[key] = value


_load_dotenv()

# Required environment variables.
# NOTE: HF_TOKEN is checked unconditionally at import time, so it must be set
# (in the environment or in .env) even when running with --no-llm.
HF_TOKEN = os.getenv('HF_TOKEN')
# The default credential is HF_TOKEN and the default model is a Hugging Face
# Hub id, so the matching default endpoint is the Hugging Face router's
# OpenAI-compatible API rather than api.openai.com. `or` also lets an empty
# value (e.g. `API_BASE_URL=` in .env) fall back to the default.
API_BASE_URL = os.getenv('API_BASE_URL') or 'https://router.huggingface.co/v1'
MODEL_NAME = os.getenv('MODEL_NAME') or 'meta-llama/Llama-3.1-8B-Instruct'

if HF_TOKEN is None:
    raise ValueError('HF_TOKEN environment variable is required')

# Shared OpenAI-compatible client used by the LLM agent.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')  # environment server base URL
TASK_NAME = 'container-stacking'
BENCHMARK = 'container-port-env'
MAX_STEPS = 200  # hard cap on steps per episode
SUCCESS_SCORE_THRESHOLD = 0.5  # minimum final score reported as success


def _strict_unit_interval(value: object, fallback: float = 0.5) -> float:
    """Clamp to a strict (0, 1) range and guard non-finite values."""
    try:
        v = float(value)
    except (TypeError, ValueError):
        v = fallback
    if not math.isfinite(v):
        v = fallback
    return min(max(v, 0.01), 0.99)


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode in the grader log."""
    record = '[START] task={} env={} model={}'.format(task, env, model)
    print(record, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record for the grader.

    Args:
        step: Step number within the episode.
        action: The action taken, already rendered as a string.
        reward: Step reward, printed with two decimals.
        done: Whether the episode ended on this step, printed lower-case.
        error: Error message for this step; a falsy value prints ``null``.
    """
    # The grader reads stdout line by line, so a multi-line error message
    # would split this record; fold line breaks into spaces.
    error_val = str(error).replace('\r', ' ').replace('\n', ' ') if error else 'null'
    done_val = str(done).lower()
    # One space after the tag, consistent with the [START] and [END] records.
    print(
        f'[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}',
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record summarising an episode.

    Rewards are rendered with two decimals and joined by commas (an empty
    list yields ``rewards=`` with nothing after it); the score gets three.
    """
    reward_texts = [format(r, '.2f') for r in rewards]
    rewards_csv = ','.join(reward_texts)
    success_text = str(success).lower()
    record = (
        f'[END] success={success_text} steps={steps} '
        f'score={score:.3f} rewards={rewards_csv}'
    )
    print(record, flush=True)


def greedy_decide(obs: dict) -> int:
    """Choose a stack for the incoming container with a one-step heuristic.

    Every stack that is not yet full gets a score; the first stack with the
    strictly highest score wins. With ``free = (max_height - depth) / max_height``:

    * ``free * (4 - priority)`` -- prefer roomy stacks, more so for containers
      with a smaller priority number;
    * on a non-empty stack, ``-10`` per level when the container's priority
      number exceeds the top container's, ``+3`` when it is smaller;
    * ``+5 * free`` when the container id is listed in ``upcoming_retrievals``;
    * ``+0.5`` for any non-empty stack.

    Returns 0 when there is no container to place or no stack is chosen.
    """
    columns = obs['stack_states']
    incoming = obs.get('current_container')
    cap = obs['max_height']
    soon_ids = set(obs.get('upcoming_retrievals', []))

    if incoming is None:
        return 0

    prio = incoming['priority']
    chosen = -1
    top_score = float('-inf')

    for position, column in enumerate(columns):
        height = len(column)
        if height >= cap:
            continue  # full stack, cannot receive anything

        free = (cap - height) / cap
        value = free * (4 - prio)

        if height:
            top_prio = column[-1]['priority']
            if prio > top_prio:
                # Would bury a container with a smaller priority number.
                value -= 10.0 * (prio - top_prio)
            elif prio < top_prio:
                value += 3.0

        if incoming['id'] in soon_ids:
            value += 5.0 * free

        if height:
            value += 0.5

        if value > top_score:
            chosen, top_score = position, value

    if chosen == -1:
        # No stack produced a comparable score; take the first one with room.
        first_open = next(
            (position for position, column in enumerate(columns) if len(column) < cap),
            None,
        )
        if first_open is not None:
            return first_open
    return max(chosen, 0)


def _build_llm_prompt(obs: dict, n_stacks: int) -> str:
    """Render the observation as the planning prompt sent to the LLM."""
    stacks = obs['stack_states']
    current = obs['current_container']
    max_height = obs['max_height']
    upcoming = obs.get('upcoming_retrievals', [])
    difficulty = obs.get('difficulty', 'medium')

    lines = []
    for i, stack in enumerate(stacks):
        if not stack:
            lines.append(f'  Stack {i}: EMPTY (0/{max_height})')
        else:
            contents = ', '.join(f"{c['id']}(p{c['priority']})" for c in stack)
            lines.append(
                f'  Stack {i}: [{contents}] depth={len(stack)}/{max_height},'
                f" top=priority-{stack[-1]['priority']}"
            )

    return (
        'You are a container yard planner. Minimize rehandle operations.\n'
        'Priority 1=URGENT (retrieved first), 2=Normal, 3=Low.\n'
        'RULE: containers above the target at retrieval = rehandles (costly).\n\n'
        f'DIFFICULTY: {difficulty}\n'
        f"UPCOMING RETRIEVALS: {upcoming or 'Unknown (hard mode)'}\n\n"
        f"CONTAINER TO PLACE: id={current['id']}, priority={current['priority']}, "
        f"weight={current['weight']}kg\n\n"
        + 'STACKS (bottom->top):\n'
        + '\n'.join(lines)
        + '\n\n'
        + f'Reply ONLY with valid JSON: {{"stack_index": <int 0-{n_stacks - 1}>}}'
    )


def _parse_stack_index(text: str) -> int:
    """Extract ``stack_index`` from a model reply.

    The span from the first ``{`` to the last ``}`` is decoded, which handles
    Markdown code fences as well as prose before or after the JSON object.

    Raises:
        ValueError: the reply is not valid JSON or the index is not an int.
        KeyError: the JSON object has no ``stack_index`` key.
        TypeError: the decoded JSON is not an object.
    """
    text = text.strip()
    start, end = text.find('{'), text.rfind('}')
    if start != -1 and end > start:
        text = text[start:end + 1]
    decision = json.loads(text)
    return int(decision['stack_index'])


def llm_decide(obs: dict, client: OpenAI) -> int:
    """Ask the LLM which stack should receive the current container.

    The observation is rendered into a prompt asking for
    ``{"stack_index": <int>}``. The answer is used only if it names an
    existing stack that still has room. In every other case the decision
    falls back to ``greedy_decide``. Those cases are: no container to place,
    an observation missing fields the prompt needs, an API error, or an
    unparseable or invalid reply.
    """
    if obs.get('current_container') is None:
        # Nothing to place; greedy_decide handles this without an API call.
        return greedy_decide(obs)

    try:
        stacks = obs['stack_states']
        max_height = obs['max_height']
        # Fall back to the number of stacks actually present if the server
        # omits the explicit count.
        n_stacks = obs.get('n_stacks', len(stacks))
        prompt = _build_llm_prompt(obs, n_stacks)
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            max_tokens=64,
            temperature=0.0,
            messages=[{'role': 'user', 'content': prompt}],
        )
        idx = _parse_stack_index(resp.choices[0].message.content or '')
        if 0 <= idx < n_stacks and len(stacks[idx]) < max_height:
            return idx
        print(
            f'[DEBUG] LLM fallback: stack_index {idx} is out of range or full',
            file=sys.stderr,
            flush=True,
        )
    except Exception as exc:  # best effort: any LLM-side failure uses greedy
        print(f'[DEBUG] LLM fallback: {exc}', file=sys.stderr, flush=True)

    return greedy_decide(obs)


def _ws_endpoint(url: str) -> str:
    """Translate an HTTP(S) environment URL into its ``/ws`` WebSocket URL."""
    ws_url = url.replace('http://', 'ws://').replace('https://', 'wss://').rstrip('/')
    # Strip trailing slashes first so '.../ws/' is not turned into '.../ws/ws'.
    return ws_url if ws_url.endswith('/ws') else ws_url + '/ws'


def _unpack_message(resp: dict) -> tuple:
    """Split a server message into ``(payload, observation)``.

    The payload is ``resp['data']``. The observation is the payload's
    ``observation`` entry, or the payload itself when that entry is absent.
    """
    payload = resp.get('data') or {}
    obs = payload.get('observation')
    if obs is None:
        obs = payload
    return payload, obs


async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = False) -> float:
    """Play one episode against the environment server and log it for the grader.

    Opens a WebSocket at ``<url>/ws`` and sends ``reset`` with *difficulty*.
    It then sends up to ``MAX_STEPS`` ``step`` messages, choosing each action
    with the LLM agent or the greedy heuristic. Finally it requests ``state``
    to read the final score.

    It prints one ``[START]`` line and one ``[STEP]`` line per step. The
    ``[END]`` line is printed from ``finally``, so it appears even if the
    episode fails part-way.

    Args:
        url: HTTP(S) or WS(S) base URL of the environment server.
        difficulty: Difficulty level passed to ``reset``.
        use_llm: Use ``llm_decide`` instead of ``greedy_decide``.

    Returns:
        The final score clamped into [0.01, 0.99]. If the server returned no
        score, or the episode failed before the score was read, this is 0.5.
    """
    import websockets

    ws_url = _ws_endpoint(url)
    llm_client = client if use_llm else None
    model_label = MODEL_NAME if use_llm else 'greedy'

    log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.5
    success = False

    try:
        loop = asyncio.get_running_loop()
        async with websockets.connect(ws_url) as ws:
            await ws.send(json.dumps({'type': 'reset', 'data': {'difficulty': difficulty}}))
            _, obs = _unpack_message(json.loads(await ws.recv()))

            for step in range(1, MAX_STEPS + 1):
                if obs.get('done', False):
                    break

                if use_llm:
                    # llm_decide makes a blocking HTTP request; run it in a worker
                    # thread so the event loop keeps servicing the WebSocket.
                    action_idx = await loop.run_in_executor(None, llm_decide, obs, llm_client)
                else:
                    action_idx = greedy_decide(obs)

                await ws.send(json.dumps({'type': 'step', 'data': {'stack_index': action_idx}}))
                payload, obs = _unpack_message(json.loads(await ws.recv()))
                raw_reward = payload.get('reward', obs.get('last_reward', 0.0))
                # Normalize step reward to strictly (0, 1) as required by the grader.
                reward = _strict_unit_interval(raw_reward, fallback=0.5)
                # bool() so a null/missing flag is logged as 'false', not 'none'.
                done = bool(payload.get('done', obs.get('done', False)))
                error = payload.get('error', None)

                rewards.append(reward)
                steps_taken = step
                log_step(step=step, action=str(action_idx), reward=reward, done=done, error=error)

                if done:
                    break

            await ws.send(json.dumps({'type': 'state'}))
            state, _ = _unpack_message(json.loads(await ws.recv()))
            score = _strict_unit_interval(state.get('score', obs.get('score', 0.5)), fallback=0.5)

        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as exc:
        print(f'[DEBUG] Episode error: {exc}', file=sys.stderr, flush=True)

    finally:
        score = _strict_unit_interval(score, fallback=0.5)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score


async def run_all(url: str, use_llm: bool = False) -> None:
    """Play one episode per difficulty level, one after another, easiest first."""
    levels = ('easy', 'medium', 'hard')
    for level in levels:
        await run_episode(url, difficulty=level, use_llm=use_llm)


if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='Container Port Baseline Agent')
    cli.add_argument('--url', default=ENV_URL)
    cli.add_argument(
        '--difficulty',
        default='all',
        choices=['easy', 'medium', 'hard', 'all'],
    )
    cli.add_argument(
        '--no-llm',
        action='store_true',
        help='Disable LLM agent and use greedy policy',
    )
    options = cli.parse_args()
    llm_enabled = not options.no_llm

    # 'all' plays easy, medium and hard back to back; any other value runs
    # a single episode at that difficulty.
    if options.difficulty == 'all':
        entry = run_all(options.url, use_llm=llm_enabled)
    else:
        entry = run_episode(options.url, difficulty=options.difficulty, use_llm=llm_enabled)
    asyncio.run(entry)