Draken1606 committed on
Commit
6f90d54
·
1 Parent(s): 16229c6

fix: ensure task scores are strictly between 0 and 1 (exclusive)

Browse files

- Update score() method in server/environment.py to return values in (0.01, 0.99) range
- Change default score from 0.0 to 0.5 in models.py
- Clean up inference.py environment variable handling per hackathon guidelines
- Initialize OpenAI client at module level with proper error checking

Files changed (5) hide show
  1. README.md +2 -0
  2. inference.py +9 -10
  3. models.py +1 -1
  4. server/app.py +166 -0
  5. server/environment.py +4 -4
README.md CHANGED
@@ -37,6 +37,8 @@ uvicorn server.app:app --host 0.0.0.0 --port 7860
37
 
38
  Web UI: `http://127.0.0.1:7860/web`
39
 
 
 
40
  For manual stateful checks, use the web endpoints:
41
 
42
  ```bash
 
37
 
38
  Web UI: `http://127.0.0.1:7860/web`
39
 
40
+ Interactive dashboard with difficulty dropdown: `http://127.0.0.1:7860/dashboard`
41
+
42
  For manual stateful checks, use the web endpoints:
43
 
44
  ```bash
inference.py CHANGED
@@ -43,16 +43,16 @@ def _load_dotenv() -> None:
43
 
44
  _load_dotenv()
45
 
 
46
  HF_TOKEN = os.getenv('HF_TOKEN')
47
  API_BASE_URL = os.getenv('API_BASE_URL', 'https://api.openai.com/v1')
48
  MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
49
- LOCAL_IMAGE_NAME = os.getenv('LOCAL_IMAGE_NAME')
50
- OPENAI_API_KEY = HF_TOKEN
- API_KEY = os.getenv('API_KEY')
51
- API_KEY = HF_TOKEN
52
- AUTH_TOKEN = HF_TOKEN or API_KEY or OPENAI_API_KEY
53
 
54
- if AUTH_TOKEN is None:
55
- raise ValueError('OPENAI_API_KEY (or API_KEY/HF_TOKEN) environment variable is required')
 
 
 
56
 
57
  ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')
58
  TASK_NAME = 'container-stacking'
@@ -188,7 +188,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
188
  if not ws_url.endswith('/ws'):
189
  ws_url = ws_url.rstrip('/') + '/ws'
190
 
191
- client = OpenAI(base_url=API_BASE_URL, api_key=AUTH_TOKEN) if use_llm else None
192
  model_label = MODEL_NAME if use_llm else 'greedy'
193
 
194
  log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)
@@ -209,7 +209,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
209
  if obs.get('done', False):
210
  break
211
 
212
- action_idx = llm_decide(obs, client) if use_llm else greedy_decide(obs)
213
 
214
  await ws.send(json.dumps({'type': 'step', 'data': {'stack_index': action_idx}}))
215
  resp = json.loads(await ws.recv())
@@ -229,8 +229,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
229
  await ws.send(json.dumps({'type': 'state'}))
230
  state_resp = json.loads(await ws.recv())
231
  state = state_resp.get('data', {})
232
- score = float(state.get('score', obs.get('score', 0.0)))
233
- score = min(max(score, 0.0), 1.0)
234
 
235
  success = score >= SUCCESS_SCORE_THRESHOLD
236
 
 
43
 
44
  _load_dotenv()
45
 
46
+ # Required environment variables
47
  HF_TOKEN = os.getenv('HF_TOKEN')
48
  API_BASE_URL = os.getenv('API_BASE_URL', 'https://api.openai.com/v1')
49
  MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
 
 
 
 
50
 
51
+ if HF_TOKEN is None:
52
+ raise ValueError('HF_TOKEN environment variable is required')
53
+
54
+ # Initialize OpenAI client
55
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
56
 
57
  ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')
58
  TASK_NAME = 'container-stacking'
 
188
  if not ws_url.endswith('/ws'):
189
  ws_url = ws_url.rstrip('/') + '/ws'
190
 
191
+ llm_client = client if use_llm else None
192
  model_label = MODEL_NAME if use_llm else 'greedy'
193
 
194
  log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)
 
209
  if obs.get('done', False):
210
  break
211
 
212
+ action_idx = llm_decide(obs, llm_client) if use_llm else greedy_decide(obs)
213
 
214
  await ws.send(json.dumps({'type': 'step', 'data': {'stack_index': action_idx}}))
215
  resp = json.loads(await ws.recv())
 
229
  await ws.send(json.dumps({'type': 'state'}))
230
  state_resp = json.loads(await ws.recv())
231
  state = state_resp.get('data', {})
232
+ score = float(state.get('score', obs.get('score', 0.5)))
 
233
 
234
  success = score >= SUCCESS_SCORE_THRESHOLD
235
 
models.py CHANGED
@@ -36,5 +36,5 @@ class ContainerObservation(Observation):
36
  max_height: int = Field(0)
37
  difficulty: str = Field("medium")
38
  last_reward: float = Field(0.0)
39
- score: float = Field(0.0, description="Normalized score 0.0-1.0")
40
  done: bool = Field(False)
 
36
  max_height: int = Field(0)
37
  difficulty: str = Field("medium")
38
  last_reward: float = Field(0.0)
39
+ score: float = Field(0.5, description="Normalized score (0.0, 1.0)")
40
  done: bool = Field(False)
server/app.py CHANGED
@@ -15,6 +15,7 @@ if str(PROJECT_ROOT) not in sys.path:
15
  os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")
16
 
17
  from openenv.core.env_server import create_web_interface_app
 
18
  import uvicorn
19
 
20
  from models import ContainerAction, ContainerObservation
@@ -28,6 +29,171 @@ app = create_web_interface_app(
28
  )
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def main() -> None:
32
  uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
33
 
 
15
  os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")
16
 
17
  from openenv.core.env_server import create_web_interface_app
18
+ from fastapi.responses import HTMLResponse
19
  import uvicorn
20
 
21
  from models import ContainerAction, ContainerObservation
 
29
  )
30
 
31
 
32
+ @app.get("/dashboard", response_class=HTMLResponse)
33
+ def dashboard() -> str:
34
+ return """
35
+ <!doctype html>
36
+ <html lang="en">
37
+ <head>
38
+ <meta charset="utf-8" />
39
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
40
+ <title>Container Port Dashboard</title>
41
+ <style>
42
+ :root {
43
+ --bg: #f4f5ef;
44
+ --card: #ffffff;
45
+ --ink: #18211f;
46
+ --accent: #0b6e4f;
47
+ --muted: #5f6a66;
48
+ --line: #d7ddd7;
49
+ }
50
+ * { box-sizing: border-box; }
51
+ body {
52
+ margin: 0;
53
+ padding: 24px;
54
+ background: radial-gradient(circle at 80% 20%, #dbeee5 0, var(--bg) 45%);
55
+ color: var(--ink);
56
+ font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
57
+ }
58
+ .wrap { max-width: 980px; margin: 0 auto; }
59
+ h1 { margin: 0 0 8px; }
60
+ p { margin: 0 0 16px; color: var(--muted); }
61
+ .panel {
62
+ background: var(--card);
63
+ border: 1px solid var(--line);
64
+ border-radius: 14px;
65
+ padding: 16px;
66
+ margin-bottom: 16px;
67
+ }
68
+ .row {
69
+ display: flex;
70
+ gap: 10px;
71
+ flex-wrap: wrap;
72
+ align-items: center;
73
+ }
74
+ select, input, button {
75
+ border: 1px solid var(--line);
76
+ border-radius: 10px;
77
+ padding: 10px 12px;
78
+ font-size: 14px;
79
+ background: #fff;
80
+ color: var(--ink);
81
+ }
82
+ button {
83
+ cursor: pointer;
84
+ background: var(--accent);
85
+ color: #fff;
86
+ border-color: var(--accent);
87
+ font-weight: 600;
88
+ }
89
+ button.secondary {
90
+ background: #fff;
91
+ color: var(--ink);
92
+ border-color: var(--line);
93
+ font-weight: 500;
94
+ }
95
+ pre {
96
+ margin: 0;
97
+ background: #0e1a17;
98
+ color: #eaf8f1;
99
+ border-radius: 12px;
100
+ padding: 14px;
101
+ overflow: auto;
102
+ min-height: 220px;
103
+ font-size: 12px;
104
+ line-height: 1.35;
105
+ }
106
+ .hint { font-size: 12px; color: var(--muted); margin-top: 8px; }
107
+ </style>
108
+ </head>
109
+ <body>
110
+ <div class="wrap">
111
+ <h1>Container Port Dashboard</h1>
112
+ <p>Pick a difficulty and step the environment manually.</p>
113
+
114
+ <div class="panel">
115
+ <div class="row">
116
+ <label for="difficulty">Difficulty</label>
117
+ <select id="difficulty">
118
+ <option value="easy">Easy</option>
119
+ <option value="medium" selected>Medium</option>
120
+ <option value="hard">Hard</option>
121
+ </select>
122
+ <button id="resetBtn">Reset</button>
123
+ <button id="stateBtn" class="secondary">State</button>
124
+ </div>
125
+ <div class="hint">Reset calls <code>/web/reset</code> with the selected mode.</div>
126
+ </div>
127
+
128
+ <div class="panel">
129
+ <div class="row">
130
+ <label for="stack">stack_index</label>
131
+ <input id="stack" type="number" min="0" step="1" value="0" />
132
+ <button id="stepBtn">Step</button>
133
+ </div>
134
+ <div class="hint">Step calls <code>/web/step</code> with action <code>{"stack_index": n}</code>.</div>
135
+ </div>
136
+
137
+ <div class="panel">
138
+ <pre id="out">Click Reset to start an episode.</pre>
139
+ </div>
140
+ </div>
141
+
142
+ <script>
143
+ const out = document.getElementById('out');
144
+ const difficulty = document.getElementById('difficulty');
145
+ const stack = document.getElementById('stack');
146
+
147
+ function show(data) {
148
+ out.textContent = JSON.stringify(data, null, 2);
149
+ }
150
+
151
+ async function postJson(url, payload) {
152
+ const res = await fetch(url, {
153
+ method: 'POST',
154
+ headers: { 'Content-Type': 'application/json' },
155
+ body: JSON.stringify(payload),
156
+ });
157
+ const data = await res.json();
158
+ show(data);
159
+ }
160
+
161
+ async function getJson(url) {
162
+ const res = await fetch(url);
163
+ const data = await res.json();
164
+ show(data);
165
+ }
166
+
167
+ document.getElementById('resetBtn').addEventListener('click', async () => {
168
+ try {
169
+ await postJson('/web/reset', { difficulty: difficulty.value });
170
+ } catch (err) {
171
+ show({ error: String(err) });
172
+ }
173
+ });
174
+
175
+ document.getElementById('stepBtn').addEventListener('click', async () => {
176
+ const idx = Number(stack.value);
177
+ try {
178
+ await postJson('/web/step', { action: { stack_index: idx } });
179
+ } catch (err) {
180
+ show({ error: String(err) });
181
+ }
182
+ });
183
+
184
+ document.getElementById('stateBtn').addEventListener('click', async () => {
185
+ try {
186
+ await getJson('/web/state');
187
+ } catch (err) {
188
+ show({ error: String(err) });
189
+ }
190
+ });
191
+ </script>
192
+ </body>
193
+ </html>
194
+ """
195
+
196
+
197
  def main() -> None:
198
  uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
199
 
server/environment.py CHANGED
@@ -229,13 +229,13 @@ class ContainerYardEnvironment(Environment):
229
  )
230
 
231
  def score(self) -> float:
232
- """Normalized score in [0.0, 1.0]. Based on actual retrievals attempted."""
233
  n_retrieved = self.retrieval_pointer # only count retrievals that actually happened
234
  worst_case = n_retrieved * (self.max_height - 1)
235
  if worst_case == 0:
236
- return 1.0
237
- score = max(0.0, 1.0 - self.rehandle_count / worst_case)
238
- return round(min(score, 1.0), 4)
239
 
240
  def get_state(self) -> dict[str, Any]:
241
  return self._observe().model_dump()
 
229
  )
230
 
231
  def score(self) -> float:
232
+ """Normalized score in (0.0, 1.0). Based on actual retrievals attempted."""
233
  n_retrieved = self.retrieval_pointer # only count retrievals that actually happened
234
  worst_case = n_retrieved * (self.max_height - 1)
235
  if worst_case == 0:
236
+ return 0.99
237
+ score = max(0.01, min(1.0 - self.rehandle_count / worst_case, 0.99))
238
+ return round(score, 4)
239
 
240
  def get_state(self) -> dict[str, Any]:
241
  return self._observe().model_dump()