Spaces:
Sleeping
Sleeping
Commit Β·
2dfa9b8
1
Parent(s): ef2bd10
bug fix
Browse files- inference.py +36 -8
inference.py
CHANGED
|
@@ -36,6 +36,7 @@ from models import Action
|
|
| 36 |
|
| 37 |
# ββ Configuration from env vars ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
|
|
|
| 39 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 40 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 41 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
|
@@ -124,6 +125,22 @@ class AuctioneerEnvClient:
|
|
| 124 |
self.task_id = task_id
|
| 125 |
self._client = httpx.AsyncClient(timeout=300.0)
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
@classmethod
|
| 128 |
async def from_docker_image(cls, image_name: str,
|
| 129 |
task_id: str = "easy_headline"):
|
|
@@ -266,9 +283,15 @@ def call_llm(client: OpenAI, system: str, user: str) -> dict:
|
|
| 266 |
|
| 267 |
|
| 268 |
# ββ Main episode loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 269 |
-
async def run_task(task_id: str, image_name: str
|
|
|
|
| 270 |
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
rewards: List[float] = []
|
| 274 |
steps_taken = 0
|
|
@@ -329,11 +352,16 @@ async def run_task(task_id: str, image_name: str) -> float:
|
|
| 329 |
|
| 330 |
|
| 331 |
async def main() -> None:
|
| 332 |
-
#
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
if not API_KEY:
|
| 338 |
print("[ERROR] Set HF_TOKEN or API_KEY env var.", flush=True)
|
| 339 |
sys.exit(1)
|
|
@@ -346,7 +374,7 @@ async def main() -> None:
|
|
| 346 |
|
| 347 |
scores: Dict[str, float] = {}
|
| 348 |
for t in tasks:
|
| 349 |
-
scores[t] = await run_task(t,
|
| 350 |
|
| 351 |
# ββ Summary ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 352 |
print("\n" + "=" * 52)
|
|
|
|
| 36 |
|
| 37 |
# ββ Configuration from env vars ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 39 |
+
ENV_URL = os.getenv("ENV_URL") or os.getenv("SPACE_URL") or os.getenv("PING_URL")
|
| 40 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 41 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 42 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
|
|
|
| 125 |
self.task_id = task_id
|
| 126 |
self._client = httpx.AsyncClient(timeout=300.0)
|
| 127 |
|
| 128 |
+
@classmethod
|
| 129 |
+
async def from_url(cls, url: str, task_id: str = "easy_headline"):
|
| 130 |
+
"""Connect directly to a remote env server (e.g. HF Space)."""
|
| 131 |
+
inst = cls(base_url=url.rstrip("/"), container_id=None, task_id=task_id)
|
| 132 |
+
# Wait for the server to become ready
|
| 133 |
+
for _ in range(90):
|
| 134 |
+
try:
|
| 135 |
+
r = await inst._client.get(f"{inst.base_url}/health")
|
| 136 |
+
if r.status_code == 200:
|
| 137 |
+
print(f"[DEBUG] Connected to remote env at {url}", flush=True)
|
| 138 |
+
return inst
|
| 139 |
+
except Exception:
|
| 140 |
+
pass
|
| 141 |
+
await asyncio.sleep(1.0)
|
| 142 |
+
raise RuntimeError(f"Remote env at {url} did not become ready")
|
| 143 |
+
|
| 144 |
@classmethod
|
| 145 |
async def from_docker_image(cls, image_name: str,
|
| 146 |
task_id: str = "easy_headline"):
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
# ββ Main episode loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 286 |
+
async def run_task(task_id: str, image_name: Optional[str] = None,
|
| 287 |
+
env_url: Optional[str] = None) -> float:
|
| 288 |
llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 289 |
+
if env_url:
|
| 290 |
+
env = await AuctioneerEnvClient.from_url(env_url, task_id=task_id)
|
| 291 |
+
elif image_name:
|
| 292 |
+
env = await AuctioneerEnvClient.from_docker_image(image_name, task_id=task_id)
|
| 293 |
+
else:
|
| 294 |
+
raise RuntimeError("No env_url or image_name provided")
|
| 295 |
|
| 296 |
rewards: List[float] = []
|
| 297 |
steps_taken = 0
|
|
|
|
| 352 |
|
| 353 |
|
| 354 |
async def main() -> None:
|
| 355 |
+
# Allow connecting to a remote URL (HF Space), local Docker image, or default localhost
|
| 356 |
+
env_url = ENV_URL
|
| 357 |
+
image_name = IMAGE_NAME
|
| 358 |
+
|
| 359 |
+
# If neither ENV_URL nor LOCAL_IMAGE_NAME is set, auto-detect running env on localhost
|
| 360 |
+
if not image_name and not env_url and not os.getenv("NO_DOCKER"):
|
| 361 |
+
# Fallback: assume env is already running on localhost:7860 (Docker default)
|
| 362 |
+
env_url = "http://localhost:7860"
|
| 363 |
+
print(f"[DEBUG] No LOCAL_IMAGE_NAME or ENV_URL set, defaulting to {env_url}", flush=True)
|
| 364 |
+
|
| 365 |
if not API_KEY:
|
| 366 |
print("[ERROR] Set HF_TOKEN or API_KEY env var.", flush=True)
|
| 367 |
sys.exit(1)
|
|
|
|
| 374 |
|
| 375 |
scores: Dict[str, float] = {}
|
| 376 |
for t in tasks:
|
| 377 |
+
scores[t] = await run_task(t, image_name=image_name, env_url=env_url)
|
| 378 |
|
| 379 |
# ββ Summary ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 380 |
print("\n" + "=" * 52)
|