"""
inference.py -- WebSocket-based inference script for DataSelectEnv

Connects to the environment via WebSocket (/ws), the required transport
on HF Spaces where HTTP /reset and /step are not accessible.

Usage:
    export HF_TOKEN=hf_...
    export API_BASE_URL=https://router.huggingface.co/v1
    export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
    export ENV_HOST=https://your-space.hf.space   # or http://localhost:7860
    python inference.py [--host URL]

Runs all 3 tasks sequentially using one WebSocket connection per task,
calls POST /grader after each episode, and prints scores plus a final summary.

STDOUT FORMAT (required by validator):
    [START] task=<task_name> env=DataSelectEnv model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...,rn>
"""

import argparse
import asyncio
import json
import os
import sys
from typing import List, Optional

import httpx
import requests
import websockets
from openai import OpenAI

# ---------------------------------------------------------------------------
# Config -- all overridable via environment variables
# ---------------------------------------------------------------------------
DEFAULT_HOST = os.environ.get("ENV_HOST", "http://localhost:7860")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
BENCHMARK = "DataSelectEnv"
SEED = 42
TASKS = ["easy", "medium", "hard"]
SYSTEM_PROMPT = """You are an intelligent data curation agent.

Your goal is to select high-quality training data from a noisy pool to improve
a machine learning classifier. At each step you observe the current state and
must choose a data selection strategy.

Observation fields:
- remaining_budget: samples you can still select (integer)
- diversity_score: std-dev of current training set features (higher = more diverse)
- noise_estimate: fraction of noisy (mislabelled) samples remaining in pool
- current_performance: validation score = 1/(1+log_loss), range [0,1]
- samples_available: unlabelled samples remaining in the pool

Respond with ONLY a valid JSON action in this exact format:
{
  "action_type": "select_batch",
  "batch_size": <integer 5-20>,
  "strategy_weights": {
    "uncertainty": <float 0-1>,
    "diversity": <float 0-1>,
    "random": <float 0-1>
  }
}

Strategy rules:
- Weights are normalized automatically (no need to sum to 1)
- noise_estimate > 0.2 -> lower uncertainty weight, raise diversity weight
- noise_estimate > 0.4 -> set uncertainty near 0, maximize diversity
- diversity_score < 0.5 -> increase diversity weight
- remaining_budget < 30 -> reduce batch_size to 5
- You may use "action_type": "stop" with batch_size 0 only when
  current_performance > 0.65 AND remaining_budget < 20
- Respond with ONLY the JSON object, no explanation, no markdown fences."""
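
# The numeric thresholds in the prompt (noise 0.2/0.4, diversity 0.5,
# budget 30/20, performance 0.65) mirror rule_based_action below, so the
# LLM policy and the rule-based fallback follow the same decision rules.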
# ---------------------------------------------------------------------------
# Structured log helpers (validator-required format)
# ---------------------------------------------------------------------------
def log_start(task: str, model: str) -> None:
    print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)


def log_step(step: int, action: dict, reward: float, done: bool,
             error: Optional[str] = None) -> None:
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={json.dumps(action)} "
        f"reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    # Clamp score to (0.001, 0.999) strictly -- validator rejects exact 0.0 or 1.0
    score = max(0.001, min(0.999, score))
    rewards_str = ",".join(f"{r:.4f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.4f} rewards={rewards_str}",
        flush=True,
    )

# ---------------------------------------------------------------------------
# Rule-based fallback (used when LLM call fails)
# ---------------------------------------------------------------------------
def rule_based_action(obs: dict) -> dict:
    """Adaptive rule-based action derived from observation."""
    noise = obs.get("noise_estimate", 0.1)
    diversity = obs.get("diversity_score", 1.0)
    budget = obs.get("remaining_budget", 100)
    perf = obs.get("current_performance", 0.5)
    batch_size = 5 if budget < 30 else 10
    if noise > 0.4:
        u, d, r = 0.05, 0.80, 0.15
    elif noise > 0.2:
        u, d, r = 0.20, 0.60, 0.20
    elif diversity < 0.5:
        u, d, r = 0.30, 0.55, 0.15
    else:
        u, d, r = 0.40, 0.40, 0.20
    if perf > 0.65 and budget < 20:
        return {"action_type": "stop", "batch_size": 0,
                "strategy_weights": {"uncertainty": u, "diversity": d, "random": r}}
    return {
        "action_type": "select_batch",
        "batch_size": batch_size,
        "strategy_weights": {"uncertainty": u, "diversity": d, "random": r},
    }
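
# Worked example (follows directly from the rules above):
#   rule_based_action({"noise_estimate": 0.5, "remaining_budget": 25,
#                      "current_performance": 0.5})
# returns a "select_batch" action with batch_size=5 (budget < 30) and
# weights {"uncertainty": 0.05, "diversity": 0.80, "random": 0.15}
# (noise > 0.4); "stop" would also require current_performance > 0.65.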
# ---------------------------------------------------------------------------
# OpenAI client factory -- robust against proxy/env issues in containers
# ---------------------------------------------------------------------------
def make_openai_client(api_key: str) -> OpenAI:
    """
    Create the required OpenAI client (as mandated by the spec).

    Uses an explicit httpx.Client with trust_env=False to bypass proxy
    auto-detection that commonly breaks SDK init in containerised environments.
    """
    base_url = (API_BASE_URL or "https://router.huggingface.co/v1").strip().rstrip("/")
    http_client = httpx.Client(trust_env=False)
    try:
        return OpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
    except Exception:
        return OpenAI(api_key=api_key, http_client=http_client)
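
# Note: the except-branch retries without base_url, which points the SDK at
# its default endpoint (api.openai.com); that only helps if the supplied key
# is a valid OpenAI key, so the first constructor call is the usual path.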
# ---------------------------------------------------------------------------
# LLM helper -- uses the required OpenAI client
# ---------------------------------------------------------------------------
def query_llm(client: OpenAI, obs: dict) -> dict:
    """Ask the LLM to produce an action given the current observation."""
    user_msg = (
        f"Current observation:\n{json.dumps(obs, indent=2)}\n\n"
        "What action do you take?"
    )
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ],
        temperature=0.0,
        max_tokens=200,
    )
    raw = response.choices[0].message.content.strip()
    # Strip markdown fences if the model wraps the JSON anyway
    if raw.startswith("```"):
        raw = raw.split("```")[1]
        if raw.startswith("json"):
            raw = raw[4:]
    action = json.loads(raw.strip())
    assert "action_type" in action
    assert "batch_size" in action
    assert "strategy_weights" in action
    return action
# ---------------------------------------------------------------------------
# WebSocket episode runner
# ---------------------------------------------------------------------------
def http_base(host: str) -> str:
    return host.rstrip("/")


def ws_url(host: str) -> str:
    base = http_base(host)
    if base.startswith("https://"):
        return "wss://" + base[len("https://"):] + "/ws"
    if base.startswith("http://"):
        return "ws://" + base[len("http://"):] + "/ws"
    return base + "/ws"
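
# Examples (hosts taken from the usage notes above, for illustration):
#   ws_url("https://your-space.hf.space")  -> "wss://your-space.hf.space/ws"
#   ws_url("http://localhost:7860/")       -> "ws://localhost:7860/ws"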
async def run_task_ws(host: str, client: Optional[OpenAI], task_id: str) -> dict:
    """Run one full episode for task_id over WebSocket. Returns grader result."""
    url = ws_url(host)
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    obs = {}
    episode_id = "unknown"
    log_start(task=task_id, model=MODEL_NAME)
    try:
        async with websockets.connect(url, open_timeout=30, ping_interval=20) as ws:
            # -- reset ----------------------------------------------------
            await ws.send(json.dumps({
                "type": "reset",
                "data": {"task_id": task_id, "seed": SEED},
            }))
            resp = json.loads(await ws.recv())
            if resp["type"] == "error":
                raise RuntimeError(f"reset error: {resp['data']['message']}")
            episode_id = resp["data"]["episode_id"]
            obs = resp["data"]["observation"]
            done = False
            # -- step loop ------------------------------------------------
            while not done:
                step_num = len(rewards) + 1
                last_error: Optional[str] = None
                # Try LLM; fall back to rule-based on any failure
                try:
                    if client is None:
                        raise ValueError("no LLM client")
                    action = query_llm(client, obs)
                except Exception as e:
                    last_error = f"{type(e).__name__}: {e}"
                    action = rule_based_action(obs)
                await ws.send(json.dumps({"type": "step", "data": action}))
                resp = json.loads(await ws.recv())
                if resp["type"] == "error":
                    err_msg = resp["data"]["message"]
                    log_step(step_num, action, 0.0, True, error=err_msg)
                    rewards.append(0.0)
                    steps_taken = step_num
                    break
                data = resp["data"]
                obs = data["observation"]
                raw_reward = data["reward"]
                reward = raw_reward["value"] if isinstance(raw_reward, dict) else float(raw_reward)
                done = data["done"]
                rewards.append(reward)
                steps_taken = step_num
                log_step(step_num, action, reward, done, error=last_error)
            # -- close WebSocket cleanly -----------------------------------
            await ws.send(json.dumps({"type": "close", "data": {}}))
            try:
                await asyncio.wait_for(ws.recv(), timeout=2.0)
            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                pass
        # -- grade via HTTP -------------------------------------------------
        r = requests.post(
            f"{http_base(host)}/grader",
            json={"episode_id": episode_id, "task_id": task_id},
            timeout=15,
        )
        r.raise_for_status()
        grade = r.json()
        score = float(grade["score"])
        success = bool(grade["passed"])
    except Exception as exc:
        print(f"[DEBUG] Episode error for {task_id}: {exc}", flush=True)
        score = 0.0
        success = False
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return {
        "task_id": task_id,
        "score": score,
        "passed": success,
        "steps": steps_taken,
        "total_reward": round(sum(rewards), 4),
        "final_performance": obs.get("current_performance", 0.0),
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def amain(host: str, client: Optional[OpenAI]) -> None:
    results = {}
    for task_id in TASKS:
        results[task_id] = await run_task_ws(host, client, task_id)
    print(f"\n{'=' * 52}", flush=True)
    print(" INFERENCE RESULTS SUMMARY", flush=True)
    print(f"{'=' * 52}", flush=True)
    print(f"{'Task':<10} {'Score':<8} {'Passed':<8} {'Final Perf':<12} {'Steps'}", flush=True)
    print("-" * 52, flush=True)
    for task_id, r in results.items():
        print(
            f"{task_id:<10} {r['score']:<8.4f} {str(r['passed']):<8} "
            f"{r['final_performance']:<12.4f} {r['steps']}",
            flush=True,
        )
    overall = sum(r["score"] for r in results.values()) / len(results)
    print(f"\nOverall mean score: {overall:.4f}", flush=True)
    print(json.dumps({"results": results, "mean_score": round(overall, 4)}, indent=2), flush=True)

def main() -> None:
    parser = argparse.ArgumentParser(description="DataSelectEnv WebSocket inference script")
    parser.add_argument("--host", default=DEFAULT_HOST,
                        help="Environment server base URL (http or https)")
    args = parser.parse_args()
    # Build OpenAI client using HF_TOKEN (required by spec)
    client: Optional[OpenAI] = None
    if not HF_TOKEN:
        print("WARNING: HF_TOKEN not set -- using rule-based fallback.", flush=True)
    else:
        try:
            client = make_openai_client(HF_TOKEN)
            print(f"OpenAI client ready | base_url={API_BASE_URL} | model={MODEL_NAME}", flush=True)
        except Exception as e:
            print(f"WARNING: Could not init OpenAI client ({e}); using rule-based fallback.", flush=True)
    # Health check -- the environment must be reachable before we start
    try:
        r = requests.get(f"{http_base(args.host)}/health", timeout=15)
        r.raise_for_status()
        print(f"Environment health: {r.json()}", flush=True)
    except Exception as e:
        print(f"ERROR: Could not reach environment at {args.host}: {e}", flush=True)
        sys.exit(1)
    asyncio.run(amain(args.host, client))

if __name__ == "__main__":
    main()