Spaces:
Sleeping
Sleeping
chore: final code refactoring for portability and robustness
- baseline.py +6 -6
- env.py +6 -32
- inference.py +2 -3
- openenv_core.py +22 -0
baseline.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
import time
|
| 4 |
-
from typing import Tuple
|
| 5 |
-
from openai import OpenAI
|
| 6 |
from env import UIEnv, Action, Observation
|
| 7 |
|
| 8 |
VALID_ACTIONS = [
|
|
@@ -49,7 +48,7 @@ def heuristic_policy(obs: Observation) -> Action:
|
|
| 49 |
return Action(type="noop")
|
| 50 |
|
| 51 |
|
| 52 |
-
def llm_policy(client:
|
| 53 |
state_desc = (
|
| 54 |
f"Device: {obs.device}\n"
|
| 55 |
f"Button Size: {obs.layout.button_size:.2f}\n"
|
|
@@ -112,7 +111,7 @@ def llm_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
|
|
| 112 |
return Action(type="noop")
|
| 113 |
|
| 114 |
|
| 115 |
-
def agent_policy(client:
|
| 116 |
heuristic_action = heuristic_policy(obs)
|
| 117 |
if heuristic_action.type != "noop":
|
| 118 |
return heuristic_action
|
|
@@ -120,7 +119,7 @@ def agent_policy(client: OpenAI, obs: Observation, model_name: str) -> Action:
|
|
| 120 |
return llm_policy(client, obs, model_name)
|
| 121 |
|
| 122 |
|
| 123 |
-
def run_episode(env: UIEnv, client:
|
| 124 |
obs = env.reset()
|
| 125 |
total_reward = 0.0
|
| 126 |
done = False
|
|
@@ -144,7 +143,7 @@ def run_episode(env: UIEnv, client: OpenAI, model_name: str) -> Tuple[float, boo
|
|
| 144 |
return total_reward, completed
|
| 145 |
|
| 146 |
|
| 147 |
-
def evaluate_task(task: str, client:
|
| 148 |
total_rewards = 0.0
|
| 149 |
completions = 0
|
| 150 |
|
|
@@ -171,6 +170,7 @@ def main():
|
|
| 171 |
api_key = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
|
| 172 |
|
| 173 |
if api_key:
|
|
|
|
| 174 |
print("Using Proxy / HF Router API...")
|
| 175 |
client = OpenAI(
|
| 176 |
base_url=base_url,
|
|
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
import time
|
| 4 |
+
from typing import Any, Tuple
|
|
|
|
| 5 |
from env import UIEnv, Action, Observation
|
| 6 |
|
| 7 |
VALID_ACTIONS = [
|
|
|
|
| 48 |
return Action(type="noop")
|
| 49 |
|
| 50 |
|
| 51 |
+
def llm_policy(client: Any, obs: Observation, model_name: str) -> Action:
|
| 52 |
state_desc = (
|
| 53 |
f"Device: {obs.device}\n"
|
| 54 |
f"Button Size: {obs.layout.button_size:.2f}\n"
|
|
|
|
| 111 |
return Action(type="noop")
|
| 112 |
|
| 113 |
|
| 114 |
+
def agent_policy(client: Any, obs: Observation, model_name: str) -> Action:
|
| 115 |
heuristic_action = heuristic_policy(obs)
|
| 116 |
if heuristic_action.type != "noop":
|
| 117 |
return heuristic_action
|
|
|
|
| 119 |
return llm_policy(client, obs, model_name)
|
| 120 |
|
| 121 |
|
| 122 |
+
def run_episode(env: UIEnv, client: Any, model_name: str) -> Tuple[float, bool]:
|
| 123 |
obs = env.reset()
|
| 124 |
total_reward = 0.0
|
| 125 |
done = False
|
|
|
|
| 143 |
return total_reward, completed
|
| 144 |
|
| 145 |
|
| 146 |
+
def evaluate_task(task: str, client: Any, model_name: str, n_episodes: int = 1) -> Tuple[float, float, float]:
|
| 147 |
total_rewards = 0.0
|
| 148 |
completions = 0
|
| 149 |
|
|
|
|
| 170 |
api_key = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
|
| 171 |
|
| 172 |
if api_key:
|
| 173 |
+
from openai import OpenAI
|
| 174 |
print("Using Proxy / HF Router API...")
|
| 175 |
client = OpenAI(
|
| 176 |
base_url=base_url,
|
env.py
CHANGED
|
@@ -16,12 +16,15 @@ from pydantic import BaseModel, Field, model_validator
|
|
| 16 |
EPS = 1e-6
|
| 17 |
|
| 18 |
def safe_grader(fn):
|
| 19 |
-
def wrapper(x):
|
| 20 |
try:
|
| 21 |
if x is None:
|
| 22 |
-
val =
|
| 23 |
else:
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
if not isinstance(val, (int, float)):
|
| 27 |
val = 0.5
|
|
@@ -179,35 +182,6 @@ class UIEnv:
|
|
| 179 |
print(f"GRADER FAIL [{t.name}]: {e}")
|
| 180 |
print(f"VALID TASKS: {valid_tasks}")
|
| 181 |
|
| 182 |
-
inputs = [None, {}, "test", 0, 1]
|
| 183 |
-
|
| 184 |
-
print("=== VALIDATOR SIMULATION ===")
|
| 185 |
-
|
| 186 |
-
valid = 0
|
| 187 |
-
|
| 188 |
-
for t in self.tasks:
|
| 189 |
-
task_valid = True
|
| 190 |
-
|
| 191 |
-
for inp in inputs:
|
| 192 |
-
try:
|
| 193 |
-
val = t.grader(inp)
|
| 194 |
-
print(t.name, inp, val)
|
| 195 |
-
|
| 196 |
-
if not isinstance(val, (int, float)):
|
| 197 |
-
task_valid = False
|
| 198 |
-
|
| 199 |
-
if not (0 < val < 1):
|
| 200 |
-
task_valid = False
|
| 201 |
-
|
| 202 |
-
except Exception as e:
|
| 203 |
-
print("CRASH:", t.name, inp, e)
|
| 204 |
-
task_valid = False
|
| 205 |
-
|
| 206 |
-
if task_valid:
|
| 207 |
-
valid += 1
|
| 208 |
-
|
| 209 |
-
print("VALID TASKS:", valid)
|
| 210 |
-
|
| 211 |
self._layout: Layout = Layout()
|
| 212 |
self._device: Literal["mobile", "desktop"] = "desktop"
|
| 213 |
self._progress: float = 0.0
|
|
|
|
| 16 |
EPS = 1e-6
|
| 17 |
|
| 18 |
def safe_grader(fn):
|
| 19 |
+
def wrapper(x=None):
|
| 20 |
try:
|
| 21 |
if x is None:
|
| 22 |
+
val = fn()
|
| 23 |
else:
|
| 24 |
+
try:
|
| 25 |
+
val = fn(x)
|
| 26 |
+
except TypeError:
|
| 27 |
+
val = fn()
|
| 28 |
|
| 29 |
if not isinstance(val, (int, float)):
|
| 30 |
val = 0.5
|
|
|
|
| 182 |
print(f"GRADER FAIL [{t.name}]: {e}")
|
| 183 |
print(f"VALID TASKS: {valid_tasks}")
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
self._layout: Layout = Layout()
|
| 186 |
self._device: Literal["mobile", "desktop"] = "desktop"
|
| 187 |
self._progress: float = 0.0
|
inference.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
-
import json
|
| 3 |
import argparse
|
| 4 |
from typing import List, Optional
|
| 5 |
from env import UIEnv, Observation, Action, clamp_score
|
| 6 |
-
from openai import OpenAI
|
| 7 |
|
| 8 |
# Required Environment Variables
|
| 9 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
|
@@ -42,6 +40,7 @@ def run_inference(task_id: str = "easy") -> None:
|
|
| 42 |
# 2. Setup Client
|
| 43 |
client = None
|
| 44 |
if API_KEY:
|
|
|
|
| 45 |
client = OpenAI(
|
| 46 |
base_url=API_BASE_URL,
|
| 47 |
api_key=API_KEY
|
|
@@ -97,4 +96,4 @@ if __name__ == "__main__":
|
|
| 97 |
parser.add_argument("--task", type=str, default=default_task, help="Task difficulty (easy, medium, hard)")
|
| 98 |
args = parser.parse_args()
|
| 99 |
|
| 100 |
-
run_inference(task_id=args.task)
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import argparse
|
| 3 |
from typing import List, Optional
|
| 4 |
from env import UIEnv, Observation, Action, clamp_score
|
|
|
|
| 5 |
|
| 6 |
# Required Environment Variables
|
| 7 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
|
|
|
| 40 |
# 2. Setup Client
|
| 41 |
client = None
|
| 42 |
if API_KEY:
|
| 43 |
+
from openai import OpenAI
|
| 44 |
client = OpenAI(
|
| 45 |
base_url=API_BASE_URL,
|
| 46 |
api_key=API_KEY
|
|
|
|
| 96 |
parser.add_argument("--task", type=str, default=default_task, help="Task difficulty (easy, medium, hard)")
|
| 97 |
args = parser.parse_args()
|
| 98 |
|
| 99 |
+
run_inference(task_id=args.task)
|
openenv_core.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_app(env_cls, action_model, observation_model) -> FastAPI:
    """Build a bare FastAPI application carrying the environment wiring.

    Minimal stand-in for the upstream OpenEnv integration: callers in this
    repo only need a FastAPI instance — with the env class and the action/
    observation models stashed on ``app.state`` — onto which they attach
    their own routes.
    """
    application = FastAPI()
    # Expose the wiring via app.state so route handlers can reach it later.
    state = application.state
    state.env_cls = env_cls
    state.action_model = action_model
    state.observation_model = observation_model
    return application
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def validate() -> None:
    """Run the project's local validation entry point.

    The import is deferred to call time so merely importing this module
    does not require ``validate_local`` to be present.
    """
    import validate_local

    validate_local.main()
|