sanjayvk21 commited on
Commit
2dfa9b8
Β·
1 Parent(s): ef2bd10
Files changed (1) hide show
  1. inference.py +36 -8
inference.py CHANGED
@@ -36,6 +36,7 @@ from models import Action
36
 
37
  # ── Configuration from env vars ──────────────────────────────────────────────
38
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 
39
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
40
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
41
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
@@ -124,6 +125,22 @@ class AuctioneerEnvClient:
124
  self.task_id = task_id
125
  self._client = httpx.AsyncClient(timeout=300.0)
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  @classmethod
128
  async def from_docker_image(cls, image_name: str,
129
  task_id: str = "easy_headline"):
@@ -266,9 +283,15 @@ def call_llm(client: OpenAI, system: str, user: str) -> dict:
266
 
267
 
268
  # ── Main episode loop ───────────────────────────────────────────────────────
269
- async def run_task(task_id: str, image_name: str) -> float:
 
270
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
271
- env = await AuctioneerEnvClient.from_docker_image(image_name, task_id=task_id)
 
 
 
 
 
272
 
273
  rewards: List[float] = []
274
  steps_taken = 0
@@ -329,11 +352,16 @@ async def run_task(task_id: str, image_name: str) -> float:
329
 
330
 
331
  async def main() -> None:
332
- # Check variables, default to sys.exit(1) so grader notices configuration faults
333
- if not IMAGE_NAME and not os.getenv("NO_DOCKER"):
334
- print("[ERROR] Set LOCAL_IMAGE_NAME env var to the Docker image name.",
335
- flush=True)
336
- sys.exit(1)
 
 
 
 
 
337
  if not API_KEY:
338
  print("[ERROR] Set HF_TOKEN or API_KEY env var.", flush=True)
339
  sys.exit(1)
@@ -346,7 +374,7 @@ async def main() -> None:
346
 
347
  scores: Dict[str, float] = {}
348
  for t in tasks:
349
- scores[t] = await run_task(t, IMAGE_NAME)
350
 
351
  # ── Summary ──────────────────────────────────────────────────────────
352
  print("\n" + "=" * 52)
 
36
 
37
  # ── Configuration from env vars ──────────────────────────────────────────────
38
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
39
+ ENV_URL = os.getenv("ENV_URL") or os.getenv("SPACE_URL") or os.getenv("PING_URL")
40
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
41
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
42
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 
125
  self.task_id = task_id
126
  self._client = httpx.AsyncClient(timeout=300.0)
127
 
128
+ @classmethod
129
+ async def from_url(cls, url: str, task_id: str = "easy_headline"):
130
+ """Connect directly to a remote env server (e.g. HF Space)."""
131
+ inst = cls(base_url=url.rstrip("/"), container_id=None, task_id=task_id)
132
+ # Wait for the server to become ready
133
+ for _ in range(90):
134
+ try:
135
+ r = await inst._client.get(f"{inst.base_url}/health")
136
+ if r.status_code == 200:
137
+ print(f"[DEBUG] Connected to remote env at {url}", flush=True)
138
+ return inst
139
+ except Exception:
140
+ pass
141
+ await asyncio.sleep(1.0)
142
+ raise RuntimeError(f"Remote env at {url} did not become ready")
143
+
144
  @classmethod
145
  async def from_docker_image(cls, image_name: str,
146
  task_id: str = "easy_headline"):
 
283
 
284
 
285
  # ── Main episode loop ───────────────────────────────────────────────────────
286
+ async def run_task(task_id: str, image_name: Optional[str] = None,
287
+ env_url: Optional[str] = None) -> float:
288
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
289
+ if env_url:
290
+ env = await AuctioneerEnvClient.from_url(env_url, task_id=task_id)
291
+ elif image_name:
292
+ env = await AuctioneerEnvClient.from_docker_image(image_name, task_id=task_id)
293
+ else:
294
+ raise RuntimeError("No env_url or image_name provided")
295
 
296
  rewards: List[float] = []
297
  steps_taken = 0
 
352
 
353
 
354
  async def main() -> None:
355
+ # Allow connecting to a remote URL (HF Space), local Docker image, or default localhost
356
+ env_url = ENV_URL
357
+ image_name = IMAGE_NAME
358
+
359
+ # If neither ENV_URL nor LOCAL_IMAGE_NAME is set, auto-detect running env on localhost
360
+ if not image_name and not env_url and not os.getenv("NO_DOCKER"):
361
+ # Fallback: assume env is already running on localhost:7860 (Docker default)
362
+ env_url = "http://localhost:7860"
363
+ print(f"[DEBUG] No LOCAL_IMAGE_NAME or ENV_URL set, defaulting to {env_url}", flush=True)
364
+
365
  if not API_KEY:
366
  print("[ERROR] Set HF_TOKEN or API_KEY env var.", flush=True)
367
  sys.exit(1)
 
374
 
375
  scores: Dict[str, float] = {}
376
  for t in tasks:
377
+ scores[t] = await run_task(t, image_name=image_name, env_url=env_url)
378
 
379
  # ── Summary ──────────────────────────────────────────────────────────
380
  print("\n" + "=" * 52)