# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Python Env Environment Client."""

from __future__ import annotations

from typing import Any, Dict
from urllib.parse import urlparse

import httpx
from openenv.core import EnvClient
from openenv.core.client_types import StepResult

try:
    from .models import (
        HealthResponse,
        MetricsResponse,
        PythonAction,
        PythonObservation,
        PythonState,
        TaskListResponse,
    )
except ImportError:
    # Fallback for running this file outside of its package (flat layout).
    from models import (  # type: ignore
        HealthResponse,
        MetricsResponse,
        PythonAction,
        PythonObservation,
        PythonState,
        TaskListResponse,
    )


def _to_http_base_url(base_url: str) -> str:
    """Return the HTTP(S) form of *base_url* without trailing slashes.

    ``wss`` maps to ``https``; ``http`` and ``https`` are kept as they are;
    every other scheme (``ws`` included) falls back to ``http``. Only the
    network location and path of the parsed URL are carried over.
    """
    parsed = urlparse(base_url)
    scheme = "https" if parsed.scheme in {"https", "wss"} else "http"
    return f"{scheme}://{parsed.netloc}{parsed.path}".rstrip("/")


class PythonEnv(EnvClient[PythonAction, PythonObservation, PythonState]):
    """Typed client for the Python code-review environment."""

    def __init__(self, base_url: str, **kwargs: Any):
        """Initialise the base client and remember the HTTP(S) base URL.

        Args:
            base_url: Server URL; forwarded unchanged to ``EnvClient``.
            **kwargs: Extra keyword arguments forwarded to ``EnvClient``.
        """
        super().__init__(base_url=base_url, **kwargs)
        # The auxiliary GET endpoints below are addressed through this URL.
        self._http_base_url = _to_http_base_url(base_url)

    def _step_payload(self, action: PythonAction) -> Dict[str, Any]:
        """Convert a validated action model to the JSON payload expected by the server."""
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[PythonObservation]:
        """Parse a server response into a typed step result.

        ``done`` and ``reward`` from the top level of *payload* are copied
        into the observation data only when the observation does not already
        carry those keys.
        """
        # ``or {}`` also covers an explicit ``"observation": null``, which
        # would otherwise make ``dict(None)`` raise ``TypeError``.
        obs_data = dict(payload.get("observation") or {})
        obs_data.setdefault("done", payload.get("done", False))
        obs_data.setdefault("reward", payload.get("reward"))
        observation = PythonObservation.model_validate(obs_data)
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> PythonState:
        """Parse the server state payload into the shared state model."""
        return PythonState.model_validate(payload)

    async def _http_get_json(self, path: str) -> Any:
        """GET ``path`` relative to the HTTP base URL and return the decoded JSON.

        Args:
            path: Endpoint path including its leading slash, e.g. ``"/tasks"``.

        Raises:
            httpx.HTTPStatusError: If the response has a 4xx or 5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self._http_base_url}{path}")
            response.raise_for_status()
            return response.json()

    async def get_tasks(self) -> TaskListResponse:
        """Fetch ``/tasks`` and validate the body as a ``TaskListResponse``."""
        return TaskListResponse.model_validate(await self._http_get_json("/tasks"))

    async def get_metrics(self) -> MetricsResponse:
        """Fetch ``/metrics`` and validate the body as a ``MetricsResponse``."""
        return MetricsResponse.model_validate(await self._http_get_json("/metrics"))

    async def get_health(self) -> HealthResponse:
        """Fetch ``/health`` and validate the body as a ``HealthResponse``."""
        return HealthResponse.model_validate(await self._http_get_json("/health"))