Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +1 -1
- client.py +38 -0
- inference.py +60 -12
- server/app.py +12 -2
Dockerfile
CHANGED
|
@@ -28,4 +28,4 @@ EXPOSE 8000
|
|
| 28 |
HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
|
| 29 |
CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1
|
| 30 |
|
| 31 |
-
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
| 28 |
HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
|
| 29 |
CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1
|
| 30 |
|
| 31 |
+
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--ws-ping-interval", "600", "--ws-ping-timeout", "600"]
|
client.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
"""Client for the metric tracker RL environment."""
|
| 2 |
|
|
|
|
| 3 |
from typing import Dict
|
| 4 |
|
| 5 |
from openenv.core import EnvClient
|
| 6 |
from openenv.core.client_types import StepResult
|
| 7 |
from openenv.core.env_server.types import State
|
|
|
|
| 8 |
|
| 9 |
from .models import MetricTrackerRlAction, MetricTrackerRlObservation
|
| 10 |
|
|
@@ -14,6 +16,42 @@ class MetricTrackerRlEnv(
|
|
| 14 |
):
|
| 15 |
"""Typed client for the metric tracking environment."""
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def _step_payload(self, action: MetricTrackerRlAction) -> Dict:
|
| 18 |
"""Serialize the action as JSON for the environment server."""
|
| 19 |
return action.model_dump()
|
|
|
|
| 1 |
"""Client for the metric tracker RL environment."""
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
from typing import Dict
|
| 5 |
|
| 6 |
from openenv.core import EnvClient
|
| 7 |
from openenv.core.client_types import StepResult
|
| 8 |
from openenv.core.env_server.types import State
|
| 9 |
+
from websockets.asyncio.client import connect as ws_connect
|
| 10 |
|
| 11 |
from .models import MetricTrackerRlAction, MetricTrackerRlObservation
|
| 12 |
|
|
|
|
| 16 |
):
|
| 17 |
"""Typed client for the metric tracking environment."""
|
| 18 |
|
| 19 |
+
async def connect(self) -> "MetricTrackerRlEnv":
|
| 20 |
+
"""Connect with websocket keepalive disabled for long-running step calls."""
|
| 21 |
+
if self._ws is not None:
|
| 22 |
+
return self
|
| 23 |
+
|
| 24 |
+
ws_url_lower = self._ws_url.lower()
|
| 25 |
+
is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower
|
| 26 |
+
old_no_proxy = os.environ.get("NO_PROXY")
|
| 27 |
+
if is_localhost:
|
| 28 |
+
current_no_proxy = old_no_proxy or ""
|
| 29 |
+
if "localhost" not in current_no_proxy.lower():
|
| 30 |
+
os.environ["NO_PROXY"] = (
|
| 31 |
+
f"{current_no_proxy},localhost,127.0.0.1"
|
| 32 |
+
if current_no_proxy
|
| 33 |
+
else "localhost,127.0.0.1"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
self._ws = await ws_connect(
|
| 38 |
+
self._ws_url,
|
| 39 |
+
open_timeout=self._connect_timeout,
|
| 40 |
+
max_size=self._max_message_size,
|
| 41 |
+
ping_interval=None,
|
| 42 |
+
ping_timeout=None,
|
| 43 |
+
)
|
| 44 |
+
except Exception as exc:
|
| 45 |
+
raise ConnectionError(f"Failed to connect to {self._ws_url}: {exc}") from exc
|
| 46 |
+
finally:
|
| 47 |
+
if is_localhost:
|
| 48 |
+
if old_no_proxy is None:
|
| 49 |
+
os.environ.pop("NO_PROXY", None)
|
| 50 |
+
else:
|
| 51 |
+
os.environ["NO_PROXY"] = old_no_proxy
|
| 52 |
+
|
| 53 |
+
return self
|
| 54 |
+
|
| 55 |
def _step_payload(self, action: MetricTrackerRlAction) -> Dict:
|
| 56 |
"""Serialize the action as JSON for the environment server."""
|
| 57 |
return action.model_dump()
|
inference.py
CHANGED
|
@@ -9,7 +9,9 @@ import textwrap
|
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
from typing import Any
|
| 11 |
|
|
|
|
| 12 |
from openai import APIStatusError, OpenAI
|
|
|
|
| 13 |
|
| 14 |
from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
|
| 15 |
from metric_tracker_rl.analysis_tools import available_analysis_methods
|
|
@@ -20,7 +22,7 @@ from metric_tracker_rl.models import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
-
IMAGE_NAME = os.getenv("IMAGE_NAME") or "
|
| 24 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
| 25 |
API_BASE_URL = (
|
| 26 |
os.getenv("API_BASE_URL")
|
|
@@ -34,6 +36,10 @@ BENCHMARK = os.getenv("MetricTrackerRl_BENCHMARK", "metric_tracker_rl")
|
|
| 34 |
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
|
| 35 |
MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
|
| 36 |
MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 39 |
"""
|
|
@@ -276,8 +282,22 @@ def preview_text(text: str, limit: int = 220) -> str:
|
|
| 276 |
|
| 277 |
async def connect_env() -> MetricTrackerRlEnv:
|
| 278 |
if BASE_URL:
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
async def execute_tool_call(
|
|
@@ -541,24 +561,52 @@ async def run_single_task(
|
|
| 541 |
}
|
| 542 |
|
| 543 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
async def main() -> None:
|
| 545 |
if not API_KEY:
|
| 546 |
raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")
|
| 547 |
|
| 548 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 549 |
-
env = await connect_env()
|
| 550 |
task_summaries: list[dict[str, Any]] = []
|
| 551 |
|
| 552 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
task_summaries.append(await run_single_task(client, env, task_id))
|
| 557 |
-
finally:
|
| 558 |
-
try:
|
| 559 |
-
await env.close()
|
| 560 |
-
except Exception:
|
| 561 |
-
pass
|
| 562 |
|
| 563 |
average_score = (
|
| 564 |
round(sum(item["normalized_score"] for item in task_summaries) / len(task_summaries), 6)
|
|
|
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
+
from openenv.core.containers.runtime.providers import LocalDockerProvider
|
| 13 |
from openai import APIStatusError, OpenAI
|
| 14 |
+
from websockets.exceptions import ConnectionClosedError
|
| 15 |
|
| 16 |
from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
|
| 17 |
from metric_tracker_rl.analysis_tools import available_analysis_methods
|
|
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
+
IMAGE_NAME = (os.getenv("IMAGE_NAME") or "metric_tracker_rl:latest").strip()
|
| 26 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
| 27 |
API_BASE_URL = (
|
| 28 |
os.getenv("API_BASE_URL")
|
|
|
|
| 36 |
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
|
| 37 |
MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
|
| 38 |
MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
|
| 39 |
+
CONNECT_TIMEOUT_S = float(os.getenv("OPENENV_CONNECT_TIMEOUT_S", "30"))
|
| 40 |
+
MESSAGE_TIMEOUT_S = float(os.getenv("OPENENV_MESSAGE_TIMEOUT_S", "180"))
|
| 41 |
+
DOCKER_WAIT_TIMEOUT_S = float(os.getenv("OPENENV_DOCKER_WAIT_TIMEOUT_S", "120"))
|
| 42 |
+
TASK_RETRY_COUNT = int(os.getenv("OPENENV_TASK_RETRY_COUNT", "1"))
|
| 43 |
|
| 44 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 45 |
"""
|
|
|
|
| 282 |
|
| 283 |
async def connect_env() -> MetricTrackerRlEnv:
|
| 284 |
if BASE_URL:
|
| 285 |
+
client = MetricTrackerRlEnv(
|
| 286 |
+
base_url=BASE_URL,
|
| 287 |
+
connect_timeout_s=CONNECT_TIMEOUT_S,
|
| 288 |
+
message_timeout_s=MESSAGE_TIMEOUT_S,
|
| 289 |
+
)
|
| 290 |
+
return await client.connect()
|
| 291 |
+
provider = LocalDockerProvider()
|
| 292 |
+
base_url = provider.start_container(IMAGE_NAME)
|
| 293 |
+
provider.wait_for_ready(base_url, timeout_s=DOCKER_WAIT_TIMEOUT_S)
|
| 294 |
+
client = MetricTrackerRlEnv(
|
| 295 |
+
base_url=base_url,
|
| 296 |
+
connect_timeout_s=CONNECT_TIMEOUT_S,
|
| 297 |
+
message_timeout_s=MESSAGE_TIMEOUT_S,
|
| 298 |
+
provider=provider,
|
| 299 |
+
)
|
| 300 |
+
return await client.connect()
|
| 301 |
|
| 302 |
|
| 303 |
async def execute_tool_call(
|
|
|
|
| 561 |
}
|
| 562 |
|
| 563 |
|
| 564 |
+
async def run_single_task_with_retries(
|
| 565 |
+
client: OpenAI,
|
| 566 |
+
task_id: str,
|
| 567 |
+
) -> dict[str, Any]:
|
| 568 |
+
"""Run one task with a fresh env connection and bounded reconnect retries."""
|
| 569 |
+
attempts = TASK_RETRY_COUNT + 1
|
| 570 |
+
last_error: Exception | None = None
|
| 571 |
+
|
| 572 |
+
for attempt in range(1, attempts + 1):
|
| 573 |
+
env = None
|
| 574 |
+
try:
|
| 575 |
+
env = await connect_env()
|
| 576 |
+
return await run_single_task(client, env, task_id)
|
| 577 |
+
except (ConnectionClosedError, ConnectionError, TimeoutError, OSError) as exc:
|
| 578 |
+
last_error = exc
|
| 579 |
+
print(
|
| 580 |
+
(
|
| 581 |
+
f"[WARN] task_id={task_id} attempt={attempt}/{attempts} "
|
| 582 |
+
f"env_connection_error={type(exc).__name__}: {exc}"
|
| 583 |
+
),
|
| 584 |
+
flush=True,
|
| 585 |
+
)
|
| 586 |
+
if attempt >= attempts:
|
| 587 |
+
raise
|
| 588 |
+
finally:
|
| 589 |
+
try:
|
| 590 |
+
if env is not None:
|
| 591 |
+
await env.close()
|
| 592 |
+
except Exception:
|
| 593 |
+
pass
|
| 594 |
+
|
| 595 |
+
assert last_error is not None
|
| 596 |
+
raise last_error
|
| 597 |
+
|
| 598 |
+
|
| 599 |
async def main() -> None:
|
| 600 |
if not API_KEY:
|
| 601 |
raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")
|
| 602 |
|
| 603 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
|
|
|
| 604 |
task_summaries: list[dict[str, Any]] = []
|
| 605 |
|
| 606 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 607 |
|
| 608 |
+
for task_id in DEFAULT_TASK_ORDER:
|
| 609 |
+
task_summaries.append(await run_single_task_with_retries(client, task_id))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
average_score = (
|
| 612 |
round(sum(item["normalized_score"] for item in task_summaries) / len(task_summaries), 6)
|
server/app.py
CHANGED
|
@@ -28,6 +28,8 @@ Usage:
|
|
| 28 |
python -m server.app
|
| 29 |
"""
|
| 30 |
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
from openenv.core.env_server.http_server import create_app
|
| 33 |
except Exception as e: # pragma: no cover
|
|
@@ -74,8 +76,16 @@ def main(host: str = "0.0.0.0", port: int = 8000):
|
|
| 74 |
uvicorn metric_tracker_rl.server.app:app --workers 4
|
| 75 |
"""
|
| 76 |
import uvicorn
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
if __name__ == "__main__":
|
|
|
|
| 28 |
python -m server.app
|
| 29 |
"""
|
| 30 |
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
try:
|
| 34 |
from openenv.core.env_server.http_server import create_app
|
| 35 |
except Exception as e: # pragma: no cover
|
|
|
|
| 76 |
uvicorn metric_tracker_rl.server.app:app --workers 4
|
| 77 |
"""
|
| 78 |
import uvicorn
|
| 79 |
+
ws_ping_interval = float(os.getenv("UVICORN_WS_PING_INTERVAL", "600"))
|
| 80 |
+
ws_ping_timeout = float(os.getenv("UVICORN_WS_PING_TIMEOUT", "600"))
|
| 81 |
+
|
| 82 |
+
uvicorn.run(
|
| 83 |
+
app,
|
| 84 |
+
host=host,
|
| 85 |
+
port=port,
|
| 86 |
+
ws_ping_interval=ws_ping_interval,
|
| 87 |
+
ws_ping_timeout=ws_ping_timeout,
|
| 88 |
+
)
|
| 89 |
|
| 90 |
|
| 91 |
if __name__ == "__main__":
|