Spaces:
Sleeping
Sleeping
Prasham.Jain
fix(submission): Dockerfile, wire-format fixes, scenario loading, real-scenario MockEnvClient
ba93ec0 | """HTTP client for the CI-Triage OpenEnv server. | |
| Wraps actions into the CITriageAction wire format ({kind, tool_call|terminal}) | |
| and unwraps CITriageObservation responses ({done, reward, metadata, payload}) | |
| back into domain Observation objects so the training loop can stay protocol-agnostic. | |
| Use MockEnvClient for offline training (no server required). | |
| """ | |
| from __future__ import annotations | |
| import httpx | |
| from ci_triage_env.schemas.action import TerminalAction, ToolCall | |
| from ci_triage_env.schemas.episode import EpisodeTrace | |
| from ci_triage_env.schemas.observation import Observation | |
def _to_wire_action(action: ToolCall | TerminalAction | dict) -> dict:
    """Convert a domain action into the CITriageAction wire envelope.

    Args:
        action: A ``ToolCall``, a ``TerminalAction``, or a raw dict. A raw dict
            that already carries a ``"kind"`` key is passed through unchanged;
            otherwise its kind is inferred from the keys it contains.

    Returns:
        ``{"kind": "tool_call", "tool_call": {...}}`` for tool calls, or
        ``{"kind": "submit_diagnosis", "terminal": {...}}`` for terminal
        actions. A raw dict whose kind cannot be inferred is returned as-is.
    """
    if isinstance(action, ToolCall):
        return {
            "kind": "tool_call",
            "tool_call": {"tool_name": action.tool_name, "args": action.args},
        }
    if isinstance(action, TerminalAction):
        # mode="json" coerces enums, datetimes, etc. into JSON-native values;
        # the default python mode can leave objects that httpx's json= encoder
        # (stdlib json.dumps) refuses to serialise.
        return {
            "kind": "submit_diagnosis",
            "terminal": action.model_dump(mode="json"),
        }
    # Raw dict: either already enveloped, or infer the kind from its content.
    if "kind" in action:
        return action
    if "tool_name" in action:
        return {"kind": "tool_call", "tool_call": action}
    if "action_type" in action or "diagnosis" in action:
        return {"kind": "submit_diagnosis", "terminal": action}
    # Unrecognised shape: forwarded unchanged; no client-side validation here.
    return action
| def _unwrap_obs(data: dict) -> dict: | |
| """Unwrap CITriageObservation envelope → domain Observation dict.""" | |
| return data.get("payload", data) | |
class EnvClient:
    """HTTP client for the CI-Triage env server (OpenEnv wire protocol).

    Outgoing actions are wrapped into the CITriageAction envelope and incoming
    observation envelopes are unwrapped, so callers work with ``Observation``
    and ``EpisodeTrace`` objects only. Usable as a context manager; the
    underlying ``httpx.Client`` is closed on exit.

    Args:
        base_url: Server base URL. Defaults to http://localhost:8000. Any
            trailing slash is stripped.
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float = 30.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self._client = httpx.Client(base_url=self.base_url, timeout=timeout)

    def reset(
        self,
        scenario_id: str | None = None,
        seed_override: int | None = None,
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            scenario_id: Scenario to load; sent as JSON null when ``None``
                (presumably letting the server pick one).
            seed_override: Sent as the ``"seed"`` field; null when ``None``.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = self._client.post(
            "/reset",
            json={"scenario_id": scenario_id, "seed": seed_override},
        )
        resp.raise_for_status()
        return Observation.model_validate(_unwrap_obs(resp.json()))

    def step(
        self,
        episode_id: str,
        action: ToolCall | TerminalAction | dict,
    ) -> Observation:
        """Send one action for ``episode_id`` and return the next observation.

        The action is converted to the wire envelope before sending.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = self._client.post(
            "/step",
            json={"episode_id": episode_id, "action": _to_wire_action(action)},
        )
        resp.raise_for_status()
        return Observation.model_validate(_unwrap_obs(resp.json()))

    def get_state(self, episode_id: str) -> dict:
        """Return the raw state dict for ``episode_id``.

        The response is unwrapped like an observation; a null or empty
        payload yields ``{}``.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        # Scope the request to the episode, as step/get_trace do.
        # NOTE(review): confirm the server reads this query param; a server
        # that ignores unknown query params presumably behaves as before.
        resp = self._client.get("/state", params={"episode_id": episode_id})
        resp.raise_for_status()
        return _unwrap_obs(resp.json()) or {}

    def get_trace(self, episode_id: str) -> EpisodeTrace:
        """Return the full EpisodeTrace for ``episode_id``.

        Intended for use after the episode has terminated. The response body
        is validated directly (no envelope unwrapping).

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = self._client.get(f"/trace/{episode_id}")
        resp.raise_for_status()
        return EpisodeTrace.model_validate(resp.json())

    def list_tools(self) -> list[dict]:
        """Return the MCP tool listing via a JSON-RPC ``tools/list`` call.

        Returns an empty list when the response has no (or a null) ``result``
        or ``tools`` field; this includes JSON-RPC error responses.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = self._client.post(
            "/mcp",
            json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
        )
        resp.raise_for_status()
        body = resp.json()
        # `or` guards against explicit nulls, which .get(key, default) does not.
        result = body.get("result") or {}
        return result.get("tools") or []

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def __enter__(self) -> EnvClient:
        return self

    def __exit__(self, *_) -> None:
        self.close()