kushalExplores committed on
Commit
821b7b8
·
verified ·
1 Parent(s): 80405b3

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. client.py +38 -0
  3. inference.py +60 -12
  4. server/app.py +12 -2
Dockerfile CHANGED
@@ -28,4 +28,4 @@ EXPOSE 8000
28
  HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
29
  CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1
30
 
31
- CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
28
  HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
29
  CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1
30
 
31
+ CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--ws-ping-interval", "600", "--ws-ping-timeout", "600"]
client.py CHANGED
@@ -1,10 +1,12 @@
1
  """Client for the metric tracker RL environment."""
2
 
 
3
  from typing import Dict
4
 
5
  from openenv.core import EnvClient
6
  from openenv.core.client_types import StepResult
7
  from openenv.core.env_server.types import State
 
8
 
9
  from .models import MetricTrackerRlAction, MetricTrackerRlObservation
10
 
@@ -14,6 +16,42 @@ class MetricTrackerRlEnv(
14
  ):
15
  """Typed client for the metric tracking environment."""
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _step_payload(self, action: MetricTrackerRlAction) -> Dict:
18
  """Serialize the action as JSON for the environment server."""
19
  return action.model_dump()
 
1
  """Client for the metric tracker RL environment."""
2
 
3
+ import os
4
  from typing import Dict
5
 
6
  from openenv.core import EnvClient
7
  from openenv.core.client_types import StepResult
8
  from openenv.core.env_server.types import State
9
+ from websockets.asyncio.client import connect as ws_connect
10
 
11
  from .models import MetricTrackerRlAction, MetricTrackerRlObservation
12
 
 
16
  ):
17
  """Typed client for the metric tracking environment."""
18
 
19
+ async def connect(self) -> "MetricTrackerRlEnv":
20
+ """Connect with websocket keepalive disabled for long-running step calls."""
21
+ if self._ws is not None:
22
+ return self
23
+
24
+ ws_url_lower = self._ws_url.lower()
25
+ is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower
26
+ old_no_proxy = os.environ.get("NO_PROXY")
27
+ if is_localhost:
28
+ current_no_proxy = old_no_proxy or ""
29
+ if "localhost" not in current_no_proxy.lower():
30
+ os.environ["NO_PROXY"] = (
31
+ f"{current_no_proxy},localhost,127.0.0.1"
32
+ if current_no_proxy
33
+ else "localhost,127.0.0.1"
34
+ )
35
+
36
+ try:
37
+ self._ws = await ws_connect(
38
+ self._ws_url,
39
+ open_timeout=self._connect_timeout,
40
+ max_size=self._max_message_size,
41
+ ping_interval=None,
42
+ ping_timeout=None,
43
+ )
44
+ except Exception as exc:
45
+ raise ConnectionError(f"Failed to connect to {self._ws_url}: {exc}") from exc
46
+ finally:
47
+ if is_localhost:
48
+ if old_no_proxy is None:
49
+ os.environ.pop("NO_PROXY", None)
50
+ else:
51
+ os.environ["NO_PROXY"] = old_no_proxy
52
+
53
+ return self
54
+
55
  def _step_payload(self, action: MetricTrackerRlAction) -> Dict:
56
  """Serialize the action as JSON for the environment server."""
57
  return action.model_dump()
inference.py CHANGED
@@ -9,7 +9,9 @@ import textwrap
9
  from dataclasses import dataclass, field
10
  from typing import Any
11
 
 
12
  from openai import APIStatusError, OpenAI
 
13
 
14
  from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
15
  from metric_tracker_rl.analysis_tools import available_analysis_methods
@@ -20,7 +22,7 @@ from metric_tracker_rl.models import (
20
  )
21
 
22
 
23
- IMAGE_NAME = os.getenv("IMAGE_NAME") or "metric_tracker:latest"
24
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
25
  API_BASE_URL = (
26
  os.getenv("API_BASE_URL")
@@ -34,6 +36,10 @@ BENCHMARK = os.getenv("MetricTrackerRl_BENCHMARK", "metric_tracker_rl")
34
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
35
  MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
36
  MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
 
 
 
 
37
 
38
  SYSTEM_PROMPT = textwrap.dedent(
39
  """
@@ -276,8 +282,22 @@ def preview_text(text: str, limit: int = 220) -> str:
276
 
277
  async def connect_env() -> MetricTrackerRlEnv:
278
  if BASE_URL:
279
- return MetricTrackerRlEnv(base_url=BASE_URL)
280
- return await MetricTrackerRlEnv.from_docker_image(IMAGE_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
 
283
  async def execute_tool_call(
@@ -541,24 +561,52 @@ async def run_single_task(
541
  }
542
 
543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  async def main() -> None:
545
  if not API_KEY:
546
  raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")
547
 
548
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
549
- env = await connect_env()
550
  task_summaries: list[dict[str, Any]] = []
551
 
552
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
553
 
554
- try:
555
- for task_id in DEFAULT_TASK_ORDER:
556
- task_summaries.append(await run_single_task(client, env, task_id))
557
- finally:
558
- try:
559
- await env.close()
560
- except Exception:
561
- pass
562
 
563
  average_score = (
564
  round(sum(item["normalized_score"] for item in task_summaries) / len(task_summaries), 6)
 
9
  from dataclasses import dataclass, field
10
  from typing import Any
11
 
12
+ from openenv.core.containers.runtime.providers import LocalDockerProvider
13
  from openai import APIStatusError, OpenAI
14
+ from websockets.exceptions import ConnectionClosedError
15
 
16
  from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
17
  from metric_tracker_rl.analysis_tools import available_analysis_methods
 
22
  )
23
 
24
 
25
+ IMAGE_NAME = (os.getenv("IMAGE_NAME") or "metric_tracker_rl:latest").strip()
26
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
27
  API_BASE_URL = (
28
  os.getenv("API_BASE_URL")
 
36
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
37
  MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
38
  MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
39
+ CONNECT_TIMEOUT_S = float(os.getenv("OPENENV_CONNECT_TIMEOUT_S", "30"))
40
+ MESSAGE_TIMEOUT_S = float(os.getenv("OPENENV_MESSAGE_TIMEOUT_S", "180"))
41
+ DOCKER_WAIT_TIMEOUT_S = float(os.getenv("OPENENV_DOCKER_WAIT_TIMEOUT_S", "120"))
42
+ TASK_RETRY_COUNT = int(os.getenv("OPENENV_TASK_RETRY_COUNT", "1"))
43
 
44
  SYSTEM_PROMPT = textwrap.dedent(
45
  """
 
282
 
283
async def connect_env() -> MetricTrackerRlEnv:
    """Build and connect an environment client.

    Uses the remote server at ``BASE_URL`` when set; otherwise launches the
    environment image through a local Docker provider and waits for readiness.
    """
    common_kwargs: dict[str, Any] = {
        "connect_timeout_s": CONNECT_TIMEOUT_S,
        "message_timeout_s": MESSAGE_TIMEOUT_S,
    }
    if BASE_URL:
        # A server is already running remotely — connect straight to it.
        env = MetricTrackerRlEnv(base_url=BASE_URL, **common_kwargs)
    else:
        # No BASE_URL: start the container locally and wait until it is healthy.
        provider = LocalDockerProvider()
        container_url = provider.start_container(IMAGE_NAME)
        provider.wait_for_ready(container_url, timeout_s=DOCKER_WAIT_TIMEOUT_S)
        env = MetricTrackerRlEnv(
            base_url=container_url,
            provider=provider,
            **common_kwargs,
        )
    return await env.connect()
301
 
302
 
303
  async def execute_tool_call(
 
561
  }
562
 
563
 
564
async def run_single_task_with_retries(
    client: OpenAI,
    task_id: str,
) -> dict[str, Any]:
    """Run one task with a fresh env connection and bounded reconnect retries.

    Args:
        client: OpenAI-compatible chat client used for model calls.
        task_id: Identifier of the task to run.

    Returns:
        The task summary dict produced by ``run_single_task``.

    Raises:
        ConnectionClosedError, ConnectionError, TimeoutError, OSError: When the
            environment connection keeps failing after all retries are spent.
    """
    attempts = TASK_RETRY_COUNT + 1

    for attempt in range(1, attempts + 1):
        env = None
        try:
            env = await connect_env()
            return await run_single_task(client, env, task_id)
        # ConnectionError/TimeoutError are OSError subclasses; listed explicitly
        # to document intent alongside the websocket-specific error.
        except (ConnectionClosedError, ConnectionError, TimeoutError, OSError) as exc:
            print(
                (
                    f"[WARN] task_id={task_id} attempt={attempt}/{attempts} "
                    f"env_connection_error={type(exc).__name__}: {exc}"
                ),
                flush=True,
            )
            if attempt >= attempts:
                # Retries exhausted: surface the final connection error.
                raise
        finally:
            # Best-effort teardown; a failed close must not mask the task result
            # or the pending exception.
            try:
                if env is not None:
                    await env.close()
            except Exception:
                pass

    # The final iteration always returns or re-raises, so the loop cannot fall
    # through; this guard replaces the old dead `assert`/`raise last_error`
    # tail (asserts are stripped under -O and the bookkeeping was unreachable).
    raise RuntimeError(f"retry loop exited unexpectedly for task {task_id}")
597
+
598
+
599
  async def main() -> None:
600
  if not API_KEY:
601
  raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")
602
 
603
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
604
  task_summaries: list[dict[str, Any]] = []
605
 
606
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
607
 
608
+ for task_id in DEFAULT_TASK_ORDER:
609
+ task_summaries.append(await run_single_task_with_retries(client, task_id))
 
 
 
 
 
 
610
 
611
  average_score = (
612
  round(sum(item["normalized_score"] for item in task_summaries) / len(task_summaries), 6)
server/app.py CHANGED
@@ -28,6 +28,8 @@ Usage:
28
  python -m server.app
29
  """
30
 
 
 
31
  try:
32
  from openenv.core.env_server.http_server import create_app
33
  except Exception as e: # pragma: no cover
@@ -74,8 +76,16 @@ def main(host: str = "0.0.0.0", port: int = 8000):
74
  uvicorn metric_tracker_rl.server.app:app --workers 4
75
  """
76
  import uvicorn
77
-
78
- uvicorn.run(app, host=host, port=port)
 
 
 
 
 
 
 
 
79
 
80
 
81
  if __name__ == "__main__":
 
28
  python -m server.app
29
  """
30
 
31
+ import os
32
+
33
  try:
34
  from openenv.core.env_server.http_server import create_app
35
  except Exception as e: # pragma: no cover
 
76
  uvicorn metric_tracker_rl.server.app:app --workers 4
77
  """
78
  import uvicorn
79
+ ws_ping_interval = float(os.getenv("UVICORN_WS_PING_INTERVAL", "600"))
80
+ ws_ping_timeout = float(os.getenv("UVICORN_WS_PING_TIMEOUT", "600"))
81
+
82
+ uvicorn.run(
83
+ app,
84
+ host=host,
85
+ port=port,
86
+ ws_ping_interval=ws_ping_interval,
87
+ ws_ping_timeout=ws_ping_timeout,
88
+ )
89
 
90
 
91
  if __name__ == "__main__":