Complete OpenEnv hackathon submission
- Add pyproject.toml + uv.lock for `openenv validate` compliance
- Add server/app.py shim (importlib) for the [project.scripts] entry point
- Rename baseline.py → inference.py (WebSocket-based LLM agent)
- Fix env.py: guaranteed label flip, noise trap, shaped reward, warmup
- Fix models.py: strategy_weights validator
- Fix server.py: /ws WebSocket, /grader with episode persistence, Reward model
- Fix openenv.yaml: sync medium/hard success criteria with grader
- Add .gitignore: exclude .DS_Store, __pycache__, .claude/
- Add websockets to requirements.txt
- Update README: HF frontmatter, actual baseline scores, full API docs

openenv validate: [OK] DataSelectEnv: Ready for multi-mode deployment

Files changed:
- .gitignore +31 -0
- README.md +133 -0
- baseline.py +0 -204
- env.py +45 -19
- inference.py +280 -0
- models.py +7 -1
- openenv.yaml +5 -4
- pyproject.toml +30 -0
- requirements.txt +2 -1
- server.py +185 -39
- server/__init__.py +0 -0
- server/app.py +39 -0
- uv.lock +0 -0
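
server/app.py (+39 lines) is collapsed in this view. As a rough sketch of what an importlib shim for the `server = "server.app:main"` entry point could look like: it would need to load the top-level server.py by path (a plain `import server` would resolve to the server/ package instead) and serve the FastAPI app with uvicorn. All names below are illustrative, not the committed code:

```python
# server/app.py – hypothetical sketch; the committed shim is not shown in this view.
import importlib.util
import pathlib


def main() -> None:
    # server.py sits one level above the server/ package and would be
    # shadowed by it on a normal import, so load it explicitly by path.
    root = pathlib.Path(__file__).resolve().parent.parent
    spec = importlib.util.spec_from_file_location("dataselect_server", root / "server.py")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes server.py, defining `app`

    import uvicorn

    uvicorn.run(module.app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
```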
.gitignore (ADDED)

````diff
@@ -0,0 +1,31 @@
+# macOS
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Python
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
+.Python
+*.egg-info/
+dist/
+build/
+.eggs/
+
+# Environments
+.env
+.venv
+env/
+venv/
+
+# Claude Code memory (local tooling only, not part of submission)
+.claude/
+
+# IDE
+.vscode/
+.idea/
+
+# OS
+Thumbs.db
````
README.md (CHANGED)

````diff
@@ -1 +1,134 @@
+---
+title: DataSelectEnv
+emoji: 🤖
+colorFrom: blue
+colorTo: green
+sdk: docker
+pinned: false
+tags:
+  - openenv
+---
+
 # DataSelectEnv – OpenEnv Environment for Data Curation in ML Training
+
+## Description
+
+DataSelectEnv is an [OpenEnv](https://github.com/meta-pytorch/OpenEnv)-compliant reinforcement learning environment for the Meta PyTorch OpenEnv Hackathon. The environment models a core problem in real-world machine learning: given a large pool of candidate training examples – some clean, some mislabelled, all with a cost to acquire – which samples should an agent select to maximise model quality under a fixed labelling budget?
+
+We implement a real incremental training loop using SGDClassifier, where agents must select training data under budget constraints that capture realistic ML dynamics: diminishing returns, redundancy penalties, and noise sensitivity. The agent observes the current classifier's validation performance, an estimate of remaining pool noise, training-set diversity, and budget left, then decides how many samples to select and which sampling strategy to weight. Three difficulty tiers (easy / medium / hard) vary the noise level, budget, and time horizon to test whether agents can adapt their strategy to the environment's constraints.
+
+---
+
+## Observation Space
+
+| Field | Type | Range | Description |
+|---|---|---|---|
+| `remaining_budget` | int | [0, budget] | Samples remaining in the selection budget |
+| `diversity_score` | float | [0, ∞) | Std-dev of current training set features (proxy for diversity) |
+| `noise_estimate` | float | [0, 1] | Fraction of noisy samples still remaining in the pool |
+| `current_performance` | float | [0, 1] | Validation score: 1 / (1 + log_loss) |
+| `samples_available` | int | [0, ~1100] | Number of unlabelled samples left in the pool |
+
+---
+
+## Action Space
+
+| Field | Type | Values | Description |
+|---|---|---|---|
+| `action_type` | string | `select_batch`, `stop` | Select a batch of data or end the episode early |
+| `batch_size` | int | ≥ 0 | Number of samples to select this step |
+| `strategy_weights.uncertainty` | float | ≥ 0 | Weight for uncertainty sampling (highest-entropy samples) |
+| `strategy_weights.diversity` | float | ≥ 0 | Weight for diversity sampling (farthest from training centroid) |
+| `strategy_weights.random` | float | ≥ 0 | Weight for uniform random sampling |
+
+Strategy weights are normalised internally and do not need to sum to 1.
+
+---
+
+## Tasks
+
+| Name | flip_y | Budget | max_steps | Success criteria | Expected random score |
+|---|---|---|---|---|---|
+| `easy` | 0.05 | 300 | 15 | performance > 0.55 | ~0.60 |
+| `medium` | 0.25 | 150 | 12 | performance > 0.52 AND avg noise ratio < 0.30 | ~0.40 |
+| `hard` | 0.30 | 100 | 8 | performance > 0.53 (+ budget efficiency) | ~0.30 |
+
+---
+
+## Reward Function
+
+```
+gain = (new_performance - old_performance) * 5.0
+       + 0.2 * std(selected_batch)                  # diversity bonus
+       + 0.2 * (new_performance - old_performance)  # alignment bonus
+
+if redundancy > 0.8:       gain *= 0.5   # redundancy penalty
+if new_performance > 0.85: gain *= 0.7   # diminishing-returns cap
+
+noise_penalty = 0.4 * noise_ratio_of_selected_batch
+
+reward = gain
+         - 0.01 * batch_size   # budget cost
+         - 0.3 * redundancy    # cosine similarity to training centroid
+         - noise_penalty
+```
+
+---
+
+## Setup
+
+### Local (pip)
+
+```bash
+pip install -r requirements.txt
+python server.py   # starts on http://localhost:7860
+```
+
+### Docker
+
+```bash
+docker build -t dataselectenv .
+docker run -p 7860:7860 dataselectenv
+```
+
+### Run inference (LLM agent)
+
+```bash
+export HF_TOKEN=hf_...                                        # or OPENAI_API_KEY=sk-...
+export API_BASE_URL=https://api-inference.huggingface.co/v1   # optional
+export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct            # optional
+
+python inference.py --host http://localhost:7860
+```
+
+### API quick-start
+
+```bash
+# Reset an episode
+curl -X POST http://localhost:7860/reset \
+  -H "Content-Type: application/json" \
+  -d '{"task_id": "easy", "seed": 42}'
+
+# Take a step
+curl -X POST http://localhost:7860/step \
+  -H "Content-Type: application/json" \
+  -d '{"action": {"action_type": "select_batch", "batch_size": 10,
+       "strategy_weights": {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2}}}'
+
+# Run the built-in baseline and get reproducible scores
+curl http://localhost:7860/baseline
+```
+
+---
+
+## Baseline Scores
+
+Scores below are from the fixed balanced agent (`uncertainty=0.4, diversity=0.4, random=0.2`, seed=42) run via `GET /baseline`.
+
+| Task | Score | Passed | Final performance |
+|---|---|---|---|
+| easy | 0.7020 | ✅ | 0.6904 |
+| medium | 0.6600 | ✅ | 0.6569 |
+| hard | 0.4174 | ✅ | 0.6176 |
````
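
The curl quick-start in the README maps directly onto a scripted episode loop. A minimal HTTP client sketch (assuming a local server on port 7860 and the `{"value": ...}`-wrapped reward that /step returns after this commit; the fixed policy mirrors the baseline agent, and none of this is committed code):

```python
# Minimal HTTP client sketch – illustrative only, not part of the commit.
import requests

HOST = "http://localhost:7860"

r = requests.post(f"{HOST}/reset", json={"task_id": "easy", "seed": 42})
r.raise_for_status()
episode_id = r.json()["episode_id"]

done = False
while not done:
    # Fixed balanced policy, mirroring the built-in baseline agent
    action = {
        "action_type": "select_batch",
        "batch_size": 10,
        "strategy_weights": {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
    }
    r = requests.post(f"{HOST}/step", json={"action": action})
    r.raise_for_status()
    data = r.json()
    # After this commit, /step wraps the reward as {"value": float}
    reward = data["reward"]["value"] if isinstance(data["reward"], dict) else data["reward"]
    done = data["done"]
    print(f"perf={data['observation']['current_performance']:.4f} reward={reward:+.4f}")

# Grade the finished episode
grade = requests.post(f"{HOST}/grader", json={"episode_id": episode_id, "task_id": "easy"}).json()
print(grade["score"], grade["passed"])
```

On HF Spaces, where only /ws is reachable, inference.py below performs the same loop over WebSocket.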
baseline.py (DELETED)

````diff
@@ -1,204 +0,0 @@
-"""
-baseline.py – Baseline inference script for DataSelectEnv
-
-Uses the OpenAI API client (as required by OpenEnv spec) to run an LLM
-agent against all 3 tasks and produce reproducible scores.
-
-Usage:
-    export OPENAI_API_KEY=sk-...
-    python baseline.py [--host http://localhost:7860]
-
-The agent is given the current observation as a JSON prompt and asked
-to return an action. This tests whether an LLM can navigate the
-data-curation environment without any fine-tuning.
-"""
-
-import argparse
-import json
-import os
-import sys
-
-import requests
-from openai import OpenAI
-
-# ---------------------------------------------------------------------------
-# Config
-# ---------------------------------------------------------------------------
-
-DEFAULT_HOST = os.environ.get("ENV_HOST", "http://localhost:7860")
-SEED = 42
-TASKS = ["easy", "medium", "hard"]
-
-SYSTEM_PROMPT = """You are an intelligent data curation agent.
-
-Your goal is to select high-quality training data from a pool to improve a machine learning model.
-At each step you observe the current state and must choose a data selection strategy.
-
-You will receive a JSON observation with these fields:
-- remaining_budget: how many samples you can still select
-- diversity_score: current diversity of training set (higher = more diverse)
-- noise_estimate: estimated fraction of noisy samples remaining in pool
-- current_performance: current model validation performance (higher = better)
-- samples_available: number of samples left in the pool
-
-You must respond with ONLY a valid JSON action in this exact format:
-{
-  "action_type": "select_batch",
-  "batch_size": <integer between 5 and 20>,
-  "strategy_weights": {
-    "uncertainty": <float 0-1>,
-    "diversity": <float 0-1>,
-    "random": <float 0-1>
-  }
-}
-
-Rules:
-- Weights do not need to sum to 1 (they are normalized automatically)
-- If noise_estimate is high (>0.2), reduce uncertainty weight and increase diversity weight
-- If diversity_score is low (<0.5), increase diversity weight
-- If remaining_budget is low (<30), use smaller batch sizes
-- You may use "action_type": "stop" to end early if performance is already good
-- Respond with ONLY the JSON object, no explanation."""
-
-
-def query_llm(client: OpenAI, observation: dict) -> dict:
-    """Ask the LLM to produce an action given the current observation."""
-    user_msg = f"Current observation:\n{json.dumps(observation, indent=2)}\n\nWhat action do you take?"
-
-    response = client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=[
-            {"role": "system", "content": SYSTEM_PROMPT},
-            {"role": "user", "content": user_msg},
-        ],
-        temperature=0.0,  # deterministic
-        max_tokens=200,
-    )
-
-    raw = response.choices[0].message.content.strip()
-
-    # Strip markdown fences if the model wraps the JSON
-    if raw.startswith("```"):
-        raw = raw.split("```")[1]
-        if raw.startswith("json"):
-            raw = raw[4:]
-    raw = raw.strip()
-
-    return json.loads(raw)
-
-
-def run_task(host: str, client: OpenAI, task_id: str) -> dict:
-    """Run one full episode on a task and return the grader result."""
-    print(f"\n{'='*50}")
-    print(f"Task: {task_id.upper()}")
-    print(f"{'='*50}")
-
-    # Reset
-    r = requests.post(f"{host}/reset", json={"task_id": task_id, "seed": SEED})
-    r.raise_for_status()
-    data = r.json()
-    episode_id = data["episode_id"]
-    obs = data["observation"]
-    print(f"Episode ID: {episode_id}")
-    print(f"Initial obs: {obs}")
-
-    step = 0
-    total_reward = 0.0
-
-    while True:
-        step += 1
-
-        # Get action from LLM
-        try:
-            action = query_llm(client, obs)
-        except (json.JSONDecodeError, KeyError) as e:
-            print(f"  Step {step}: LLM returned invalid JSON ({e}), using fallback action")
-            action = {
-                "action_type": "select_batch",
-                "batch_size": 10,
-                "strategy_weights": {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
-            }
-
-        # Execute action
-        r = requests.post(f"{host}/step", json={"action": action})
-        r.raise_for_status()
-        result = r.json()
-
-        obs = result["observation"]
-        reward = result["reward"]
-        done = result["done"]
-        total_reward += reward
-
-        print(f"  Step {step:2d} | perf={obs['current_performance']:.4f} "
-              f"budget={obs['remaining_budget']:3d} reward={reward:+.4f} "
-              f"noise_est={obs['noise_estimate']:.3f}")
-
-        if done:
-            break
-
-    print(f"\nEpisode done after {step} steps | total_reward={total_reward:.4f}")
-    print(f"Final performance: {obs['current_performance']:.4f}")
-
-    # Grade
-    r = requests.post(f"{host}/grader", json={"episode_id": episode_id, "task_id": task_id})
-    r.raise_for_status()
-    grade = r.json()
-
-    print(f"Score: {grade['score']:.4f}")
-    print(f"Passed: {grade['passed']}")
-    print(f"Details: {grade['breakdown']}")
-
-    return {
-        "task_id": task_id,
-        "score": grade["score"],
-        "passed": grade["passed"],
-        "breakdown": grade["breakdown"],
-        "steps": step,
-        "total_reward": round(total_reward, 4),
-        "final_performance": obs["current_performance"],
-    }
-
-
-def main():
-    parser = argparse.ArgumentParser(description="DataSelectEnv baseline inference script")
-    parser.add_argument("--host", default=DEFAULT_HOST, help="Environment server URL")
-    args = parser.parse_args()
-
-    api_key = os.environ.get("OPENAI_API_KEY")
-    if not api_key:
-        print("ERROR: OPENAI_API_KEY environment variable not set.")
-        sys.exit(1)
-
-    client = OpenAI(api_key=api_key)
-
-    # Health check
-    try:
-        r = requests.get(f"{args.host}/health", timeout=5)
-        r.raise_for_status()
-        print(f"Connected to {args.host} → {r.json()}")
-    except Exception as e:
-        print(f"ERROR: Could not reach environment at {args.host}: {e}")
-        sys.exit(1)
-
-    # Run all tasks
-    results = {}
-    for task_id in TASKS:
-        results[task_id] = run_task(args.host, client, task_id)
-
-    # Summary
-    print(f"\n{'='*50}")
-    print("BASELINE RESULTS SUMMARY")
-    print(f"{'='*50}")
-    print(f"{'Task':<10} {'Score':<8} {'Passed':<8} {'Final Perf':<12} {'Steps'}")
-    print("-" * 50)
-    for task_id, r in results.items():
-        print(f"{task_id:<10} {r['score']:<8.4f} {str(r['passed']):<8} "
-              f"{r['final_performance']:<12.4f} {r['steps']}")
-
-    overall = sum(r["score"] for r in results.values()) / len(results)
-    print(f"\nOverall mean score: {overall:.4f}")
-    print(json.dumps({"baseline_results": results, "mean_score": round(overall, 4)}, indent=2))
-
-
-if __name__ == "__main__":
-    main()
````
env.py (CHANGED)

````diff
@@ -1,3 +1,4 @@
+import math
 import numpy as np
 import random

@@ -7,7 +8,7 @@ from sklearn.metrics import log_loss
 from sklearn.preprocessing import StandardScaler

 from models import Observation, Action, EnvState
-from sampling import sample_uncertainty, sample_diversity, sample_random, entropy
+from sampling import sample_uncertainty, sample_diversity, sample_random, entropy
 from reward import mean_cosine, running_mean

@@ -57,7 +58,7 @@ class DataSelectEnv:
             n_informative=5,
             n_redundant=5,
             n_clusters_per_class=2,
-            class_sep=1.
+            class_sep=1.0,
             flip_y=0.1,
             random_state=42,  # dataset skeleton is fixed; noise injection varies by seed
         )

@@ -75,9 +76,11 @@ class DataSelectEnv:
         flip_prob = self.cfg["data"].get("flip_y", 0.1)
         noise_mask = np.random.rand(len(y_pool)) < flip_prob
         y_pool_noisy = y_pool.copy()
-
-
-
+        # Guaranteed label flip (1 - y), not random assignment.
+        # Random assignment gives the correct label 50% of the time, halving
+        # effective noise. Flipping guarantees every noise_mask sample is wrong.
+        y_pool_noisy[noise_mask] = 1 - y_pool[noise_mask]
+        X_pool[noise_mask] += np.random.normal(0, 0.1, X_pool[noise_mask].shape)

         # Fresh model every episode
         self.model = SGDClassifier(

@@ -90,7 +93,7 @@ class DataSelectEnv:
         )

         # Warm start on seed data
-        for _ in range(
+        for _ in range(10):
             self.model.partial_fit(X_seed, y_seed, classes=np.unique(y))

         self._episode_state = DatasetState(

@@ -129,23 +132,34 @@ class DataSelectEnv:
         total = sum(w.values()) + 1e-8
         w = {k: v / total for k, v in w.items()}

+        min_b = self.cfg.get("min_batch", 1)
         b = min(action.batch_size, s.budget)
+        if b < min_b and action.action_type != "stop":
+            b = min_b  # enforce minimum; prevents single-sample gaming
         if b <= 0:
             return self._obs(), -0.01, False, {"error": "empty batch"}

         if action.action_type == "stop":
-
+            perf_threshold = self.cfg.get("stop_threshold", 0.60)
+            if s.performance >= perf_threshold:
+                stop_reward = 0.05 * s.budget
+            else:
+                stop_reward = -1.0
+            return self._obs(), stop_reward, True, {}

         # Uncertainty + noise trap
         proba_pool = self.model.predict_proba(s.X_pool)
         H = entropy(proba_pool)

-        #
-
-
-        #
-
-
+        # Noise trap: boost entropy of noisy samples so uncertainty sampling is
+        # attracted to them. Capped at 0.55 (< log(2) ≈ 0.693 max binary entropy)
+        # so clean uncertain samples can still compete – trap misleads rather
+        # than completely overrides, keeping uncertainty a near-miss on hard.
+        max_entropy = math.log(2)  # ≈ 0.693 for binary classifier
+        flip_prob = self.cfg["data"].get("flip_y", 0.1)
+        boost_raw = 0.1 + flip_prob * 2.0
+        noise_boost = s.noise_mask.astype(float) * min(boost_raw, 0.55)
+        H_adj = H + noise_boost

         # Sampling
         n_u = int(b * w.get("uncertainty", 0))

@@ -163,7 +177,7 @@ class DataSelectEnv:

         Xb, yb = s.X_pool[idx], s.y_pool[idx]
         selected_noise = s.noise_mask[idx]
-        noise_ratio = float(np.mean(selected_noise))
+        noise_ratio = float(np.mean(selected_noise))

         # Remove selected samples from pool → keep noise_mask in sync
         keep = np.ones(len(s.X_pool), dtype=bool)

@@ -180,11 +194,18 @@ class DataSelectEnv:
         new = self._score()

         # ----------------------------------------------------------------
-        # Reward design
+        # Reward design
         # ----------------------------------------------------------------
         gain = (new - old) * 5.0
-
-
+
+        # Distance-based diversity bonus: rewards batches that cover regions
+        # far from existing training data. Diversity sampling scores high
+        # (~0.25), random scores average (~0.22), uncertainty scores low
+        # (~0.15) because boundary samples cluster near the centroid.
+        diversity_bonus = float(np.mean(
+            np.linalg.norm(Xb - s.train_centroid, axis=1)
+        )) * 0.05
+        gain += diversity_bonus

         redundancy = mean_cosine(Xb, s.train_centroid)
         if redundancy > 0.8:

@@ -192,9 +213,14 @@ class DataSelectEnv:
         if new > 0.85:
             gain *= 0.7

-
+        # Noise penalty scales with task difficulty: easy is forgiving,
+        # hard severely punishes noisy selections.
+        flip_prob = self.cfg["data"].get("flip_y", 0.1)
+        noise_scale = 1.0 + flip_prob * 2.0  # 1.1 easy | 1.5 medium | 1.6 hard
+        noise_penalty = noise_scale * noise_ratio
         reward = gain - 0.01 * b - 0.3 * redundancy - noise_penalty
-        reward += 0.
+        reward += 0.15  # baseline: keeps reward in mixed-sign territory so
+                        # RL agents receive positive signal for good steps
         # ----------------------------------------------------------------

         # Update state
````
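
The noise-trap hunk is easier to see in isolation. A self-contained sketch (toy data and standalone names, not the environment's own `sampling` module) of how the capped entropy boost steers an uncertainty ranking toward mislabelled samples:

```python
# Standalone illustration of the env.py noise trap – hypothetical toy data.
import numpy as np

def binary_entropy(p1: np.ndarray) -> np.ndarray:
    # Entropy of Bernoulli(p1); maximum is log(2) ~= 0.693 at p1 = 0.5
    p1 = np.clip(p1, 1e-9, 1 - 1e-9)
    return -(p1 * np.log(p1) + (1 - p1) * np.log(1 - p1))

rng = np.random.default_rng(0)
p1 = rng.uniform(0.05, 0.95, size=10)    # model's P(class=1) for 10 pool samples
noise_mask = np.zeros(10, dtype=bool)
noise_mask[[2, 7]] = True                 # pretend samples 2 and 7 are mislabelled

H = binary_entropy(p1)
flip_prob = 0.30                          # hard-task setting
boost = min(0.1 + flip_prob * 2.0, 0.55)  # raw boost 0.7, capped at 0.55
H_adj = H + noise_mask.astype(float) * boost

top3 = np.argsort(H_adj)[::-1][:3]        # what pure uncertainty sampling would pick
print(top3, noise_mask[top3])             # the two noisy samples outrank the clean ones
```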
inference.py (ADDED)

````diff
@@ -0,0 +1,280 @@
+"""
+inference.py – WebSocket-based inference script for DataSelectEnv
+
+Connects to the environment via WebSocket (/ws) – the required transport
+on HF Spaces where HTTP /reset and /step are not accessible.
+
+Usage:
+    export HF_TOKEN=hf_...                        # or OPENAI_API_KEY=sk-...
+    export ENV_HOST=https://your-space.hf.space   # or http://localhost:7860
+    export API_BASE_URL=https://api-inference.huggingface.co/v1   # optional
+    export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct            # optional
+    python inference.py [--host URL]
+
+Runs all 3 tasks sequentially using one WebSocket connection per task,
+calls POST /grader after each episode, prints scores and final summary.
+Designed to complete in under 20 minutes on 2 vCPU / 8 GB RAM.
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import sys
+
+import requests
+import websockets
+from openai import OpenAI
+
+# ---------------------------------------------------------------------------
+# Config – all overridable via environment variables
+# ---------------------------------------------------------------------------
+
+DEFAULT_HOST = os.environ.get("ENV_HOST", "http://localhost:7860")
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+SEED = 42
+TASKS = ["easy", "medium", "hard"]
+FALLBACK_ACTION = {
+    "action_type": "select_batch",
+    "batch_size": 10,
+    "strategy_weights": {"uncertainty": 0.3, "diversity": 0.5, "random": 0.2},
+}
+
+SYSTEM_PROMPT = """You are an intelligent data curation agent.
+
+Your goal is to select high-quality training data from a noisy pool to improve
+a machine learning classifier. At each step you observe the current state and
+must choose a data selection strategy.
+
+Observation fields:
+- remaining_budget: samples you can still select (integer)
+- diversity_score: std-dev of current training set features (higher = more diverse)
+- noise_estimate: fraction of noisy (mislabelled) samples remaining in pool
+- current_performance: validation score = 1/(1+log_loss), range [0,1]
+- samples_available: unlabelled samples remaining in the pool
+
+Respond with ONLY a valid JSON action in this exact format:
+{
+  "action_type": "select_batch",
+  "batch_size": <integer 5–20>,
+  "strategy_weights": {
+    "uncertainty": <float 0–1>,
+    "diversity": <float 0–1>,
+    "random": <float 0–1>
+  }
+}
+
+Strategy rules:
+- Weights are normalized automatically (no need to sum to 1)
+- noise_estimate > 0.2 → lower uncertainty weight, raise diversity weight
+- noise_estimate > 0.4 → set uncertainty near 0, maximize diversity
+- diversity_score < 0.5 → increase diversity weight
+- remaining_budget < 30 → reduce batch_size to 5
+- You may use "action_type": "stop" with batch_size 0 only when
+  current_performance > 0.65 AND remaining_budget < 20
+- Respond with ONLY the JSON object, no explanation, no markdown fences."""
+
+
+# ---------------------------------------------------------------------------
+# LLM helper
+# ---------------------------------------------------------------------------
+
+def query_llm(client: OpenAI, observation: dict) -> dict:
+    """Ask the LLM to produce an action given the current observation."""
+    user_msg = (
+        f"Current observation:\n{json.dumps(observation, indent=2)}\n\n"
+        "What action do you take?"
+    )
+    response = client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_msg},
+        ],
+        temperature=0.0,
+        max_tokens=200,
+    )
+    raw = response.choices[0].message.content.strip()
+    # Strip markdown fences if model wraps JSON
+    if raw.startswith("```"):
+        raw = raw.split("```")[1]
+        if raw.startswith("json"):
+            raw = raw[4:]
+    return json.loads(raw.strip())
+
+
+# ---------------------------------------------------------------------------
+# WebSocket episode runner
+# ---------------------------------------------------------------------------
+
+def http_base(host: str) -> str:
+    """Return HTTP base URL (strip trailing slash)."""
+    return host.rstrip("/")
+
+
+def ws_url(host: str) -> str:
+    """Convert http(s):// base URL to ws(s):// WebSocket URL."""
+    base = http_base(host)
+    if base.startswith("https://"):
+        return "wss://" + base[len("https://"):] + "/ws"
+    if base.startswith("http://"):
+        return "ws://" + base[len("http://"):] + "/ws"
+    return base + "/ws"
+
+
+async def run_task_ws(host: str, client: OpenAI, task_id: str) -> dict:
+    """
+    Run one full episode for task_id over a WebSocket connection.
+    Returns the grader result dict.
+    """
+    print(f"\n{'='*52}")
+    print(f"  Task: {task_id.upper()}")
+    print(f"{'='*52}")
+
+    url = ws_url(host)
+    print(f"  Connecting to {url} ...")
+
+    async with websockets.connect(url, open_timeout=30, ping_interval=20) as ws:
+
+        # ── reset ────────────────────────────────────────────────────────
+        await ws.send(json.dumps({
+            "type": "reset",
+            "data": {"task_id": task_id, "seed": SEED},
+        }))
+        resp = json.loads(await ws.recv())
+        if resp["type"] == "error":
+            raise RuntimeError(f"reset error: {resp['data']['message']}")
+
+        episode_id = resp["data"]["episode_id"]
+        obs = resp["data"]["observation"]
+        print(f"  Episode ID: {episode_id}")
+        print(f"  Initial obs: {obs}")
+
+        step = 0
+        total_reward = 0.0
+        done = False
+
+        # ── step loop ────────────────────────────────────────────────────
+        while not done:
+            step += 1
+
+            # Get action from LLM (with fallback on parse error)
+            try:
+                action = query_llm(client, obs)
+                # Validate required keys are present
+                assert "action_type" in action
+                assert "batch_size" in action
+                assert "strategy_weights" in action
+            except Exception as e:
+                print(f"  Step {step}: LLM parse error ({e}), using fallback")
+                action = FALLBACK_ACTION
+
+            await ws.send(json.dumps({"type": "step", "data": action}))
+            resp = json.loads(await ws.recv())
+
+            if resp["type"] == "error":
+                print(f"  Step {step}: server error: {resp['data']['message']}")
+                break
+
+            data = resp["data"]
+            obs = data["observation"]
+            # reward is wrapped in {"value": float} per Reward model
+            raw_reward = data["reward"]
+            reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
+            done = data["done"]
+            total_reward += reward
+
+            print(
+                f"  Step {step:2d} | perf={obs['current_performance']:.4f} "
+                f"budget={obs['remaining_budget']:3d} "
+                f"reward={reward:+.4f} "
+                f"noise_est={obs['noise_estimate']:.3f}"
+            )
+
+        # ── close WebSocket cleanly ──────────────────────────────────────
+        await ws.send(json.dumps({"type": "close", "data": {}}))
+        try:
+            await asyncio.wait_for(ws.recv(), timeout=2.0)
+        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
+            pass
+
+    print(f"\n  Episode done after {step} steps | total_reward={total_reward:.4f}")
+    print(f"  Final performance: {obs['current_performance']:.4f}")
+
+    # ── grade via HTTP (grader endpoint doesn't need WebSocket) ──────────
+    r = requests.post(
+        f"{http_base(host)}/grader",
+        json={"episode_id": episode_id, "task_id": task_id},
+        timeout=15,
+    )
+    r.raise_for_status()
+    grade = r.json()
+
+    print(f"  Score: {grade['score']:.4f}")
+    print(f"  Passed: {grade['passed']}")
+    print(f"  Details: {grade['breakdown']}")
+
+    return {
+        "task_id": task_id,
+        "score": grade["score"],
+        "passed": grade["passed"],
+        "breakdown": grade["breakdown"],
+        "steps": step,
+        "total_reward": round(total_reward, 4),
+        "final_performance": obs["current_performance"],
+    }
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+async def amain(host: str, client: OpenAI) -> None:
+    results = {}
+    for task_id in TASKS:
+        results[task_id] = await run_task_ws(host, client, task_id)
+
+    print(f"\n{'='*52}")
+    print("  INFERENCE RESULTS SUMMARY")
+    print(f"{'='*52}")
+    print(f"{'Task':<10} {'Score':<8} {'Passed':<8} {'Final Perf':<12} {'Steps'}")
+    print("-" * 52)
+    for task_id, r in results.items():
+        print(
+            f"{task_id:<10} {r['score']:<8.4f} {str(r['passed']):<8} "
+            f"{r['final_performance']:<12.4f} {r['steps']}"
+        )
+
+    overall = sum(r["score"] for r in results.values()) / len(results)
+    print(f"\nOverall mean score: {overall:.4f}")
+    print(json.dumps({"results": results, "mean_score": round(overall, 4)}, indent=2))
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="DataSelectEnv WebSocket inference script")
+    parser.add_argument("--host", default=DEFAULT_HOST,
+                        help="Environment server base URL (http or https)")
+    args = parser.parse_args()
+
+    api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("ERROR: Set HF_TOKEN or OPENAI_API_KEY environment variable.")
+        sys.exit(1)
+
+    client = OpenAI(api_key=api_key, base_url=API_BASE_URL)
+
+    # Health check over HTTP
+    try:
+        r = requests.get(f"{http_base(args.host)}/health", timeout=10)
+        r.raise_for_status()
+        print(f"Connected to {args.host} → {r.json()}")
+    except Exception as e:
+        print(f"ERROR: Could not reach environment at {args.host}: {e}")
+        sys.exit(1)
+
+    asyncio.run(amain(args.host, client))
+
+
+if __name__ == "__main__":
+    main()
````
models.py (CHANGED)

````diff
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 from typing import Dict, Literal, Optional


@@ -15,6 +15,12 @@ class Action(BaseModel):
     batch_size: int = Field(ge=0)
     strategy_weights: Dict[str, float]

+    @validator('strategy_weights')
+    def weights_not_empty(cls, v):
+        if not v:
+            raise ValueError('strategy_weights cannot be empty')
+        return v
+

 class Reward(BaseModel):
     value: float
````
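
One note on the validator: with the pinned pydantic==2.7.4, the v1-style `@validator` still runs but emits a deprecation warning; the v2-native spelling is `field_validator`. A minimal equivalent sketch (not part of the commit):

```python
# pydantic v2 spelling of the same check – illustrative, not the committed code.
from typing import Dict

from pydantic import BaseModel, field_validator


class Action(BaseModel):
    strategy_weights: Dict[str, float]

    @field_validator("strategy_weights")
    @classmethod
    def weights_not_empty(cls, v: Dict[str, float]) -> Dict[str, float]:
        if not v:
            raise ValueError("strategy_weights cannot be empty")
        return v
```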
openenv.yaml (CHANGED)

````diff
@@ -78,16 +78,16 @@ tasks:
     description: >
       High noise (flip_y=0.25), budget=150, max_steps=12.
       Agent must reach performance > 0.52 while keeping average
-      noise selection rate below 0.
+      noise selection rate below 0.50.
-    success_criteria: "current_performance > 0.52 AND avg_noise_ratio < 0.
+    success_criteria: "current_performance > 0.52 AND avg_noise_ratio < 0.50"

   - id: hard
     difficulty: hard
     description: >
       High noise (flip_y=0.30), tight budget=100, max_steps=8.
-      Agent must hit performance > 0.
+      Agent must hit performance > 0.58 efficiently.
       Grader scores performance and budget efficiency jointly.
-    success_criteria: "current_performance > 0.
+    success_criteria: "current_performance > 0.58, scored jointly with efficiency"

 reward:
   type: continuous

@@ -98,6 +98,7 @@ reward:
     Provides dense signal throughout the episode – not just at termination.

 endpoints:
+  websocket: WS /ws   # primary transport; required on HF Spaces
   reset: POST /reset
   step: POST /step
   state: GET /state
````
pyproject.toml (ADDED)

````diff
@@ -0,0 +1,30 @@
+[build-system]
+requires = ["setuptools>=68", "wheel"]
+build-backend = "setuptools.backends.legacy:build"
+
+[project]
+name = "dataselectenv"
+version = "0.1.0"
+description = "RL environment for optimal data selection under cost, noise, and diversity constraints"
+requires-python = ">=3.10"
+dependencies = [
+    "scikit-learn",
+    "numpy",
+    "pydantic",
+    "fastapi",
+    "uvicorn[standard]",
+    "openai",
+    "websockets",
+    "python-dotenv",
+    "openenv-core>=0.2.0",
+]
+
+[project.scripts]
+server = "server.app:main"
+
+[project.optional-dependencies]
+dev = ["pytest", "httpx"]
+
+[tool.openenv]
+env_id = "DataSelectEnv-v0"
+entry_point = "server:app"
````
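
One caveat in the [build-system] table above: `setuptools.backends.legacy:build` is not a backend that setuptools ships; the standard value is `setuptools.build_meta` (or `setuptools.build_meta:__legacy__`), so PEP 517 builds against this file as committed would likely fail to import the backend.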
requirements.txt (CHANGED)

````diff
@@ -4,4 +4,5 @@ pydantic==2.7.4
 numpy==1.26.4
 scikit-learn==1.5.1
 openai==1.40.0
-requests==2.32.3
+requests==2.32.3
+websockets>=12.0
````
server.py
CHANGED
|
@@ -20,12 +20,12 @@ import uuid
|
|
| 20 |
from typing import Any, Dict, Optional
|
| 21 |
|
| 22 |
import numpy as np
|
| 23 |
-
from fastapi import FastAPI, HTTPException
|
| 24 |
from fastapi.middleware.cors import CORSMiddleware
|
| 25 |
from pydantic import BaseModel
|
| 26 |
|
| 27 |
from env import DataSelectEnv
|
| 28 |
-
from models import Action, EnvState, Observation
|
| 29 |
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
# App
|
|
@@ -63,6 +63,7 @@ BASE_CFG = {
|
|
| 63 |
"budget": 300,
|
| 64 |
"max_steps": 15,
|
| 65 |
"alpha": 0.2,
|
|
|
|
| 66 |
}
|
| 67 |
|
| 68 |
# ---------------------------------------------------------------------------
|
|
@@ -80,9 +81,10 @@ TASKS = {
|
|
| 80 |
),
|
| 81 |
"success_criteria": "current_performance > 0.55 at episode end",
|
| 82 |
"cfg_overrides": {
|
| 83 |
-
"data":
|
| 84 |
-
"budget":
|
| 85 |
-
"max_steps":
|
|
|
|
| 86 |
},
|
| 87 |
},
|
| 88 |
"medium": {
|
|
@@ -91,13 +93,14 @@ TASKS = {
|
|
| 91 |
"description": (
|
| 92 |
"High noise (flip_y=0.25), budget=150, max_steps=12. "
|
| 93 |
"Agent must reach performance > 0.52 while keeping average "
|
| 94 |
-
"noise selection rate below 0.
|
| 95 |
),
|
| 96 |
-
"success_criteria": "current_performance > 0.52 AND avg noise_ratio < 0.
|
| 97 |
"cfg_overrides": {
|
| 98 |
-
"data":
|
| 99 |
-
"budget":
|
| 100 |
-
"max_steps":
|
|
|
|
| 101 |
},
|
| 102 |
},
|
| 103 |
"hard": {
|
|
@@ -105,15 +108,16 @@ TASKS = {
|
|
| 105 |
"difficulty": "hard",
|
| 106 |
"description": (
|
| 107 |
"High noise (flip_y=0.30), tight budget=100, max_steps=8. "
|
| 108 |
-
"Agent must hit performance > 0.
|
| 109 |
"Grader scores performance and budget efficiency jointly. "
|
| 110 |
"Requires precise noise-aware + diversity-aware strategy."
|
| 111 |
),
|
| 112 |
-
"success_criteria": "performance > 0.
|
| 113 |
"cfg_overrides": {
|
| 114 |
-
"data":
|
| 115 |
-
"budget":
|
| 116 |
-
"max_steps":
|
|
|
|
| 117 |
},
|
| 118 |
},
|
| 119 |
}
|
|
@@ -159,6 +163,9 @@ class EpisodeStore:
|
|
| 159 |
|
| 160 |
store = EpisodeStore()
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
# ---------------------------------------------------------------------------
|
| 163 |
# Request / response schemas
|
| 164 |
# ---------------------------------------------------------------------------
|
|
@@ -199,19 +206,19 @@ def _grade(task_id: str, obs: Observation, noise_ratios: list, cfg: dict) -> Gra
|
|
| 199 |
perf = obs.current_performance
|
| 200 |
|
| 201 |
if task_id == "easy":
|
| 202 |
-
# Single dimension: raw performance
|
| 203 |
-
score = float(np.clip((perf - 0.
|
| 204 |
-
passed = perf > 0.
|
| 205 |
breakdown: Dict[str, Any] = {"performance_score": round(score, 4)}
|
| 206 |
|
| 207 |
elif task_id == "medium":
|
| 208 |
avg_noise = float(np.mean(noise_ratios)) if noise_ratios else 1.0
|
| 209 |
# Performance sub-score
|
| 210 |
perf_score = float(np.clip((perf - 0.42) / (0.62 - 0.42), 0.0, 1.0))
|
| 211 |
-
# Noise avoidance sub-score: full marks at 0 noise, zero at >=0.
|
| 212 |
-
noise_score = float(np.clip(1.0 - avg_noise / 0.
|
| 213 |
score = round(0.6 * perf_score + 0.4 * noise_score, 4)
|
| 214 |
-
passed = perf > 0.52 and avg_noise < 0.
|
| 215 |
breakdown = {
|
| 216 |
"performance_score": round(perf_score, 4),
|
| 217 |
"noise_score": round(noise_score, 4),
|
|
@@ -221,12 +228,12 @@ def _grade(task_id: str, obs: Observation, noise_ratios: list, cfg: dict) -> Gra
|
|
| 221 |
else: # hard
|
| 222 |
budget_total = cfg["budget"]
|
| 223 |
budget_used = budget_total - obs.remaining_budget
|
| 224 |
-
perf_score = float(np.clip((perf - 0.
|
| 225 |
-
# Efficiency:
|
| 226 |
-
#
|
| 227 |
-
efficiency = float(np.clip(1.0 - budget_used / budget_total
|
| 228 |
score = round(0.65 * perf_score + 0.35 * efficiency, 4)
|
| 229 |
-
passed = perf > 0.
|
| 230 |
breakdown = {
|
| 231 |
"performance_score": round(perf_score, 4),
|
| 232 |
"efficiency_score": round(efficiency, 4),
|
|
@@ -311,11 +318,19 @@ def step(req: StepRequest):
|
|
| 311 |
if "noise_ratio" in info:
|
| 312 |
store.noise_ratios.append(info["noise_ratio"])
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
return {
|
| 315 |
"episode_id": store.episode_id,
|
| 316 |
"step": store.step_count,
|
| 317 |
"observation": obs.model_dump(),
|
| 318 |
-
"reward": round(float(reward), 6),
|
| 319 |
"done": done,
|
| 320 |
"info": info,
|
| 321 |
}
|
|
@@ -360,28 +375,40 @@ def tasks():
|
|
| 360 |
@app.post("/grader")
|
| 361 |
def grader(req: GraderRequest):
|
| 362 |
"""
|
| 363 |
-
Score
|
| 364 |
|
| 365 |
Body: { "episode_id": "...", "task_id": "easy|medium|hard" }
|
| 366 |
-
|
| 367 |
"""
|
| 368 |
-
if
|
| 369 |
-
raise HTTPException(
|
| 370 |
-
status_code=400,
|
| 371 |
-
detail="episode_id does not match the current episode.",
|
| 372 |
-
)
|
| 373 |
-
if not store.done:
|
| 374 |
raise HTTPException(
|
| 375 |
status_code=400,
|
| 376 |
-
detail="
|
| 377 |
)
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
raise HTTPException(
|
| 380 |
status_code=400,
|
| 381 |
-
detail=f"
|
| 382 |
)
|
| 383 |
|
| 384 |
-
return _grade(req.task_id,
|
| 385 |
|
| 386 |
|
| 387 |
@app.get("/baseline")
|
|
@@ -435,6 +462,125 @@ def baseline():
|
|
| 435 |
}
|
| 436 |
|
| 437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
# ---------------------------------------------------------------------------
|
| 439 |
# Entry point
|
| 440 |
# ---------------------------------------------------------------------------
|
|
|
|
| 20 |
from typing import Any, Dict, Optional
|
| 21 |
|
| 22 |
import numpy as np
|
| 23 |
+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
| 24 |
from fastapi.middleware.cors import CORSMiddleware
|
| 25 |
from pydantic import BaseModel
|
| 26 |
|
| 27 |
from env import DataSelectEnv
|
| 28 |
+
from models import Action, EnvState, Observation, Reward
|
| 29 |
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
# App
|
|
|
|
| 63 |
"budget": 300,
|
| 64 |
"max_steps": 15,
|
| 65 |
"alpha": 0.2,
|
| 66 |
+
"min_batch": 5,
|
| 67 |
}
|
| 68 |
|
| 69 |
# ---------------------------------------------------------------------------
|
|
|
|
| 81 |
),
|
| 82 |
"success_criteria": "current_performance > 0.55 at episode end",
|
| 83 |
"cfg_overrides": {
|
| 84 |
+
"data": {"flip_y": 0.05},
|
| 85 |
+
"budget": 300,
|
| 86 |
+
"max_steps": 15,
|
| 87 |
+
"stop_threshold": 0.60,
|
| 88 |
},
|
| 89 |
},
|
| 90 |
"medium": {
|
|
|
|
| 93 |
"description": (
|
| 94 |
"High noise (flip_y=0.25), budget=150, max_steps=12. "
|
| 95 |
"Agent must reach performance > 0.52 while keeping average "
|
| 96 |
+
"noise selection rate below 0.45. Uncertainty-only strategies fail."
|
| 97 |
),
|
| 98 |
+
"success_criteria": "current_performance > 0.52 AND avg noise_ratio < 0.45",
|
| 99 |
"cfg_overrides": {
|
| 100 |
+
"data": {"flip_y": 0.25},
|
| 101 |
+
"budget": 150,
|
| 102 |
+
"max_steps": 12,
|
| 103 |
+
"stop_threshold": 0.57,
|
| 104 |
},
|
| 105 |
},
|
| 106 |
"hard": {
|
|
|
|
| 108 |
"difficulty": "hard",
|
| 109 |
"description": (
|
| 110 |
"High noise (flip_y=0.30), tight budget=100, max_steps=8. "
|
| 111 |
+
"Agent must hit performance > 0.58 efficiently. "
|
| 112 |
"Grader scores performance and budget efficiency jointly. "
|
| 113 |
"Requires precise noise-aware + diversity-aware strategy."
|
| 114 |
),
|
| 115 |
+
"success_criteria": "performance > 0.58, scored jointly with budget efficiency",
|
| 116 |
"cfg_overrides": {
|
| 117 |
+
"data": {"flip_y": 0.30},
|
| 118 |
+
"budget": 100,
|
| 119 |
+
"max_steps": 8,
|
| 120 |
+
"stop_threshold": 0.62,
|
| 121 |
},
|
| 122 |
},
|
| 123 |
}
|
|
|
|
| 163 |
|
| 164 |
store = EpisodeStore()
|
| 165 |
|
| 166 |
+
# Completed episodes keyed by episode_id so /grader works after a subsequent reset()
|
| 167 |
+
_completed: Dict[str, Dict[str, Any]] = {}
|
| 168 |
+
|
| 169 |
# ---------------------------------------------------------------------------
|
| 170 |
# Request / response schemas
|
| 171 |
# ---------------------------------------------------------------------------
|
|
|
|
| 206 |
perf = obs.current_performance
|
| 207 |
|
| 208 |
if task_id == "easy":
|
| 209 |
+
# Single dimension: raw performance β range [0.55, 0.75] avoids saturation
|
| 210 |
+
score = float(np.clip((perf - 0.55) / (0.75 - 0.55), 0.0, 1.0))
|
| 211 |
+
passed = perf > 0.62
|
| 212 |
breakdown: Dict[str, Any] = {"performance_score": round(score, 4)}
|
| 213 |
|
| 214 |
elif task_id == "medium":
|
| 215 |
avg_noise = float(np.mean(noise_ratios)) if noise_ratios else 1.0
|
| 216 |
# Performance sub-score
|
| 217 |
perf_score = float(np.clip((perf - 0.42) / (0.62 - 0.42), 0.0, 1.0))
|
| 218 |
+
# Noise avoidance sub-score: full marks at 0 noise, zero at >=0.50
|
| 219 |
+
noise_score = float(np.clip(1.0 - avg_noise / 0.50, 0.0, 1.0))
|
| 220 |
score = round(0.6 * perf_score + 0.4 * noise_score, 4)
|
| 221 |
+
passed = perf > 0.52 and avg_noise < 0.50
|
| 222 |
breakdown = {
|
| 223 |
"performance_score": round(perf_score, 4),
|
| 224 |
"noise_score": round(noise_score, 4),
|
|
|
|
| 228 |
else: # hard
|
| 229 |
budget_total = cfg["budget"]
|
| 230 |
budget_used = budget_total - obs.remaining_budget
|
| 231 |
+
perf_score = float(np.clip((perf - 0.50) / (0.72 - 0.50), 0.0, 1.0))
|
| 232 |
+
# Efficiency: fraction of budget saved β no grace offset so it
|
| 233 |
+
# actually varies (0.0 = all spent, 1.0 = nothing spent)
|
| 234 |
+
efficiency = float(np.clip(1.0 - budget_used / budget_total, 0.0, 1.0))
|
| 235 |
score = round(0.65 * perf_score + 0.35 * efficiency, 4)
|
| 236 |
+
passed = perf > 0.58
|
| 237 |
breakdown = {
|
| 238 |
"performance_score": round(perf_score, 4),
|
| 239 |
"efficiency_score": round(efficiency, 4),
|
|
|
|
| 318 |
if "noise_ratio" in info:
|
| 319 |
store.noise_ratios.append(info["noise_ratio"])
|
| 320 |
|
| 321 |
+
if done:
|
| 322 |
+
_completed[store.episode_id] = {
|
| 323 |
+
"final_obs": obs,
|
| 324 |
+
"noise_ratios": list(store.noise_ratios),
|
| 325 |
+
"cfg": store._cfg,
|
| 326 |
+
"task_id": store.task_id,
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
return {
|
| 330 |
"episode_id": store.episode_id,
|
| 331 |
"step": store.step_count,
|
| 332 |
"observation": obs.model_dump(),
|
| 333 |
+
"reward": Reward(value=round(float(reward), 6)).model_dump(),
|
| 334 |
"done": done,
|
| 335 |
"info": info,
|
| 336 |
}
|
|
|
|
...

 @app.post("/grader")
 def grader(req: GraderRequest):
     """
+    Score a completed episode.

     Body: { "episode_id": "...", "task_id": "easy|medium|hard" }
+    Works even after a subsequent reset(); the episode is looked up by episode_id.
     """
+    if req.task_id not in TASKS:
         raise HTTPException(
             status_code=400,
+            detail=f"Unknown task_id '{req.task_id}'.",
         )
+
+    record = _completed.get(req.episode_id)
+    if record is None:
+        # Fall back to the active episode if it matches and is done
+        if store.episode_id == req.episode_id and store.done:
+            record = {
+                "final_obs": store.final_obs,
+                "noise_ratios": store.noise_ratios,
+                "cfg": store._cfg,
+                "task_id": store.task_id,
+            }
+        else:
+            raise HTTPException(
+                status_code=404,
+                detail="episode_id not found or episode is not finished yet.",
+            )
+
+    if req.task_id != record["task_id"]:
         raise HTTPException(
             status_code=400,
+            detail=f"task_id mismatch: episode was '{record['task_id']}', got '{req.task_id}'.",
         )

+    return _grade(req.task_id, record["final_obs"], record["noise_ratios"], record["cfg"])


 @app.get("/baseline")
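A sketch of exercising the endpoint once an episode has finished, assuming a local run on the Space's default port 7860 and the requests package (neither the URL nor requests is pinned by this diff):

import requests

BASE = "http://localhost:7860"  # assumption: server started locally

episode_id = "..."  # as returned by /reset or the WebSocket reset reply
resp = requests.post(
    BASE + "/grader",
    json={"episode_id": episode_id, "task_id": "medium"},
)
resp.raise_for_status()
print(resp.json())  # score / passed / breakdown, as computed by _grade()

Because finished episodes are persisted in _completed, this call still succeeds after the client has reset into a new episode.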
...

 }


+# ---------------------------------------------------------------------------
+# WebSocket endpoint: required by the OpenEnv spec and the primary client
+# transport on HF Spaces (HTTP /reset and /step are inaccessible after
+# deployment there).
+#
+# Protocol: every message is {"type": str, "data": dict}
+# Client → server types: "reset", "step", "state", "close"
+# Server → client types: mirrors the client type on success, "error" on failure
+# ---------------------------------------------------------------------------
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+
+    # Per-connection isolated state (no shared store)
+    ws_env: DataSelectEnv | None = None
+    ws_cfg: dict | None = None
+    ws_episode_id: str | None = None
+    ws_task_id: str | None = None
+    ws_noise_ratios: list = []
+    ws_done: bool = False
+    ws_final_obs: Observation | None = None
+
+    async def send_error(message: str, code: str = "error") -> None:
+        await websocket.send_json({"type": "error", "data": {"message": message, "code": code}})
+
+    try:
+        while True:
+            raw = await websocket.receive_json()
+            msg_type = raw.get("type")
+            msg_data = raw.get("data", {})
+
+            # ── reset ─────────────────────────────────────────────────────
+            if msg_type == "reset":
+                tid = msg_data.get("task_id", "easy")
+                seed = int(msg_data.get("seed", 42))
+
+                if tid not in TASKS:
+                    await send_error(
+                        f"Unknown task_id '{tid}'. Valid: {list(TASKS.keys())}",
+                        "invalid_task",
+                    )
+                    continue
+
+                ws_cfg = _build_cfg(tid)
+                ws_env = DataSelectEnv(ws_cfg, seed=seed)
+                obs = ws_env.reset()
+                ws_task_id = tid
+                ws_episode_id = str(uuid.uuid4())
+                ws_noise_ratios = []
+                ws_done = False
+                ws_final_obs = obs
+
+                await websocket.send_json({
+                    "type": "reset",
+                    "data": {
+                        "episode_id": ws_episode_id,
+                        "task_id": ws_task_id,
+                        "observation": obs.model_dump(),
+                        "reward": 0.0,
+                        "done": False,
+                    },
+                })
+
+            # ── step ──────────────────────────────────────────────────────
+            elif msg_type == "step":
+                if ws_env is None or ws_done:
+                    await send_error("No active episode. Send a reset message first.", "no_episode")
+                    continue
+
+                try:
+                    action = Action(**msg_data)
+                except Exception as exc:
+                    await send_error(f"Invalid action: {exc}", "invalid_action")
+                    continue
+
+                obs, reward, done, info = ws_env.step(action)
+                ws_done = done
+                ws_final_obs = obs
+                if "noise_ratio" in info:
+                    ws_noise_ratios.append(info["noise_ratio"])
+
+                await websocket.send_json({
+                    "type": "step",
+                    "data": {
+                        "episode_id": ws_episode_id,
+                        "observation": obs.model_dump(),
+                        "reward": round(float(reward), 6),
+                        "done": done,
+                        "info": info,
+                    },
+                })
+
+            # ── state ─────────────────────────────────────────────────────
+            elif msg_type == "state":
+                if ws_env is None:
+                    state_data = {
+                        "step_count": 0, "remaining_budget": None,
+                        "current_performance": None, "pool_size": None, "done": False,
+                    }
+                else:
+                    state_data = ws_env.get_state().model_dump()
+
+                await websocket.send_json({
+                    "type": "state",
+                    "data": {"episode_id": ws_episode_id, "task_id": ws_task_id, **state_data},
+                })
+
+            # ── close ─────────────────────────────────────────────────────
+            elif msg_type == "close":
+                await websocket.send_json({"type": "close", "data": {}})
+                break
+
+            else:
+                await send_error(f"Unknown message type '{msg_type}'", "unknown_type")
+
+    except WebSocketDisconnect:
+        pass  # client disconnected cleanly
+

 # ---------------------------------------------------------------------------
 # Entry point
 # ---------------------------------------------------------------------------
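Since /ws is the transport that survives deployment to Spaces, a client episode loop looks roughly like the sketch below. It uses the websockets package added to requirements.txt; the action payload is illustrative only, since the real fields are defined by the Action model in models.py:

import asyncio
import json

import websockets


async def run_episode(url: str = "ws://localhost:7860/ws") -> None:
    async with websockets.connect(url) as ws:
        # Start an episode; the reply mirrors the "reset" type on success.
        await ws.send(json.dumps({"type": "reset", "data": {"task_id": "easy", "seed": 42}}))
        reply = json.loads(await ws.recv())
        print("episode:", reply["data"]["episode_id"])

        done = False
        while not done:
            action = {"selected_indices": [0, 1, 2]}  # hypothetical Action fields
            await ws.send(json.dumps({"type": "step", "data": action}))
            reply = json.loads(await ws.recv())
            if reply["type"] == "error":
                raise RuntimeError(reply["data"]["message"])
            done = reply["data"]["done"]

        await ws.send(json.dumps({"type": "close", "data": {}}))


asyncio.run(run_episode())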
server/__init__.py
ADDED

File without changes
server/app.py
ADDED

@@ -0,0 +1,39 @@
+"""
+server/app.py: OpenEnv entry point shim.
+
+Loads the root-level server.py by absolute file path so that the
+server/ package and the root server.py file can coexist without
+Python's import system preferring the package over the module.
+"""
+import importlib.util
+import os
+import sys
+
+# Add project root to path so root-level modules (env, models, etc.) resolve
+_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if _root not in sys.path:
+    sys.path.insert(0, _root)
+
+# Load root server.py by file path, registered under a private module name
+# to avoid collision with the `server` package name.
+_spec = importlib.util.spec_from_file_location(
+    "_dataselectenv_server",
+    os.path.join(_root, "server.py"),
+)
+_mod = importlib.util.module_from_spec(_spec)
+sys.modules["_dataselectenv_server"] = _mod
+_spec.loader.exec_module(_mod)
+
+# Re-export the FastAPI app; this is what openenv and uvicorn look for.
+app = _mod.app
+
+
+def main() -> None:
+    """Entry point required by openenv validate and [project.scripts]."""
+    import uvicorn
+    port = int(os.environ.get("PORT", 7860))
+    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
+
+
+if __name__ == "__main__":
+    main()
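A side effect of the shim worth noting: once server.app has been imported, the root module is also reachable under the private name it was registered with, so both import paths resolve to the same FastAPI object. A sanity-check sketch, assuming it is run from the project root:

from server.app import app as shim_app  # triggers the importlib load above
import _dataselectenv_server            # already in sys.modules at this point

assert shim_app is _dataselectenv_server.app  # one app object, two import paths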
uv.lock
ADDED

The diff for this file is too large to render. See raw diff