Kolaps27 commited on
Commit
0e5f237
·
1 Parent(s): 69c8431

chore: final code refactoring for portability and robustness

Browse files
Files changed (4) hide show
  1. baseline.py +6 -6
  2. env.py +6 -32
  3. inference.py +2 -3
  4. openenv_core.py +22 -0
baseline.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  import random
3
  import time
4
- from typing import Tuple
5
- from openai import OpenAI
6
  from env import UIEnv, Action, Observation
7
 
8
  VALID_ACTIONS = [
@@ -49,7 +48,7 @@ def heuristic_policy(obs: Observation) -> Action:
49
  return Action(type="noop")
50
 
51
 
52
- def llm_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
53
  state_desc = (
54
  f"Device: {obs.device}\n"
55
  f"Button Size: {obs.layout.button_size:.2f}\n"
@@ -112,7 +111,7 @@ def llm_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
112
  return Action(type="noop")
113
 
114
 
115
- def agent_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
116
  heuristic_action = heuristic_policy(obs)
117
  if heuristic_action.type != "noop":
118
  return heuristic_action
@@ -120,7 +119,7 @@ def agent_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
120
  return llm_policy(client, obs, model_name)
121
 
122
 
123
- def run_episode(env: UIEnv, client: OpenAI, model_name: str) -> Tuple[float, bool]:
124
  obs = env.reset()
125
  total_reward = 0.0
126
  done = False
@@ -144,7 +143,7 @@ def run_episode(env: UIEnv, client: OpenAI, model_name: str) -> Tuple[float, boo
144
  return total_reward, completed
145
 
146
 
147
- def evaluate_task(task: str, client: OpenAI, model_name: str, n_episodes: int = 1) -> Tuple[float, float, float]:
148
  total_rewards = 0.0
149
  completions = 0
150
 
@@ -171,6 +170,7 @@ def main():
171
  api_key = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
172
 
173
  if api_key:
 
174
  print("Using Proxy / HF Router API...")
175
  client = OpenAI(
176
  base_url=base_url,
 
1
  import os
2
  import random
3
  import time
4
+ from typing import Any, Tuple
 
5
  from env import UIEnv, Action, Observation
6
 
7
  VALID_ACTIONS = [
 
48
  return Action(type="noop")
49
 
50
 
51
+ def llm_policy(client: Any, obs: Observation, model_name: str) -> Action:
52
  state_desc = (
53
  f"Device: {obs.device}\n"
54
  f"Button Size: {obs.layout.button_size:.2f}\n"
 
111
  return Action(type="noop")
112
 
113
 
114
+ def agent_policy(client: Any, obs: Observation, model_name: str) -> Action:
115
  heuristic_action = heuristic_policy(obs)
116
  if heuristic_action.type != "noop":
117
  return heuristic_action
 
119
  return llm_policy(client, obs, model_name)
120
 
121
 
122
+ def run_episode(env: UIEnv, client: Any, model_name: str) -> Tuple[float, bool]:
123
  obs = env.reset()
124
  total_reward = 0.0
125
  done = False
 
143
  return total_reward, completed
144
 
145
 
146
+ def evaluate_task(task: str, client: Any, model_name: str, n_episodes: int = 1) -> Tuple[float, float, float]:
147
  total_rewards = 0.0
148
  completions = 0
149
 
 
170
  api_key = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
171
 
172
  if api_key:
173
+ from openai import OpenAI
174
  print("Using Proxy / HF Router API...")
175
  client = OpenAI(
176
  base_url=base_url,
env.py CHANGED
@@ -16,12 +16,15 @@ from pydantic import BaseModel, Field, model_validator
16
  EPS = 1e-6
17
 
18
  def safe_grader(fn):
19
- def wrapper(x):
20
  try:
21
  if x is None:
22
- val = 0.5
23
  else:
24
- val = fn(x)
 
 
 
25
 
26
  if not isinstance(val, (int, float)):
27
  val = 0.5
@@ -179,35 +182,6 @@ class UIEnv:
179
  print(f"GRADER FAIL [{t.name}]: {e}")
180
  print(f"VALID TASKS: {valid_tasks}")
181
 
182
- inputs = [None, {}, "test", 0, 1]
183
-
184
- print("=== VALIDATOR SIMULATION ===")
185
-
186
- valid = 0
187
-
188
- for t in self.tasks:
189
- task_valid = True
190
-
191
- for inp in inputs:
192
- try:
193
- val = t.grader(inp)
194
- print(t.name, inp, val)
195
-
196
- if not isinstance(val, (int, float)):
197
- task_valid = False
198
-
199
- if not (0 < val < 1):
200
- task_valid = False
201
-
202
- except Exception as e:
203
- print("CRASH:", t.name, inp, e)
204
- task_valid = False
205
-
206
- if task_valid:
207
- valid += 1
208
-
209
- print("VALID TASKS:", valid)
210
-
211
  self._layout: Layout = Layout()
212
  self._device: Literal["mobile", "desktop"] = "desktop"
213
  self._progress: float = 0.0
 
16
  EPS = 1e-6
17
 
18
  def safe_grader(fn):
19
+ def wrapper(x=None):
20
  try:
21
  if x is None:
22
+ val = fn()
23
  else:
24
+ try:
25
+ val = fn(x)
26
+ except TypeError:
27
+ val = fn()
28
 
29
  if not isinstance(val, (int, float)):
30
  val = 0.5
 
182
  print(f"GRADER FAIL [{t.name}]: {e}")
183
  print(f"VALID TASKS: {valid_tasks}")
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  self._layout: Layout = Layout()
186
  self._device: Literal["mobile", "desktop"] = "desktop"
187
  self._progress: float = 0.0
inference.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
- import json
3
  import argparse
4
  from typing import List, Optional
5
  from env import UIEnv, Observation, Action, clamp_score
6
- from openai import OpenAI
7
 
8
  # Required Environment Variables
9
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
@@ -42,6 +40,7 @@ def run_inference(task_id: str = "easy") -> None:
42
  # 2. Setup Client
43
  client = None
44
  if API_KEY:
 
45
  client = OpenAI(
46
  base_url=API_BASE_URL,
47
  api_key=API_KEY
@@ -97,4 +96,4 @@ if __name__ == "__main__":
97
  parser.add_argument("--task", type=str, default=default_task, help="Task difficulty (easy, medium, hard)")
98
  args = parser.parse_args()
99
 
100
- run_inference(task_id=args.task)
 
1
  import os
 
2
  import argparse
3
  from typing import List, Optional
4
  from env import UIEnv, Observation, Action, clamp_score
 
5
 
6
  # Required Environment Variables
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
 
40
  # 2. Setup Client
41
  client = None
42
  if API_KEY:
43
+ from openai import OpenAI
44
  client = OpenAI(
45
  base_url=API_BASE_URL,
46
  api_key=API_KEY
 
96
  parser.add_argument("--task", type=str, default=default_task, help="Task difficulty (easy, medium, hard)")
97
  args = parser.parse_args()
98
 
99
+ run_inference(task_id=args.task)
openenv_core.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from fastapi import FastAPI
4
+
5
+
6
def create_app(env_cls, action_model, observation_model) -> FastAPI:
    """Build a bare FastAPI app preconfigured with environment metadata.

    Local stand-in for the richer upstream OpenEnv integration: this repo
    only needs a FastAPI instance it can attach routes to.

    Args:
        env_cls: environment class to expose via the app.
        action_model: model describing inbound actions.
        observation_model: model describing outbound observations.

    Returns:
        A fresh FastAPI application with the three objects stashed on
        ``app.state`` so later-registered route handlers can reach them.
    """
    application = FastAPI()
    # Attach the environment class and I/O models to app.state in one pass.
    for attr, value in (
        ("env_cls", env_cls),
        ("action_model", action_model),
        ("observation_model", observation_model),
    ):
        setattr(application.state, attr, value)
    return application
17
+
18
+
19
def validate() -> None:
    """Run the project's local validation entry point.

    The import is deliberately deferred to call time so that merely
    importing this module never pulls in ``validate_local`` (and any
    side effects its import may carry).
    """
    from validate_local import main as run_validation

    run_validation()