Spaces:
Sleeping
Sleeping
File size: 10,238 Bytes
4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 ed51f28 8be6018 ed51f28 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 ed51f28 8be6018 4c1ad51 ed51f28 4c1ad51 ed51f28 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 4c1ad51 8be6018 e1814ef 4c1ad51 8be6018 4c1ad51 e1814ef 4c1ad51 ba93ec0 4c1ad51 e1814ef 4c1ad51 ba93ec0 e1814ef 8be6018 4c1ad51 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | """OpenEnv-compliant CI triage environment.
Public surface (registered automatically by ``openenv.core.env_server.http_server.create_app``):
- ``POST /reset`` β ``ResetResponse`` with the initial domain Observation
- ``POST /step`` β ``StepResponse`` for ``CITriageAction``
- ``GET /state`` β ``CITriageState`` (current EpisodeState)
- ``GET /schema`` β action / observation / state JSON Schemas
- ``POST /mcp`` β JSON-RPC ``tools/list`` and ``tools/call``
- ``WS /mcp`` β MCP over WebSocket
- ``WS /ws`` β persistent OpenEnv session (canonical ``EnvClient`` path)
- ``GET /docs`` β Swagger UI
Tools are registered with a ``FastMCP`` server so MCP clients can discover them
via ``tools/list``. Step-level tool calls go through ``CITriageAction``
(``kind="tool_call"``) so budget/cost bookkeeping stays inside the env.
"""
from __future__ import annotations
import logging
import os
import uuid
from typing import Any
from fastmcp import FastMCP
from openenv.core.env_server.http_server import create_app as openenv_create_app
from openenv.core.env_server.mcp_environment import MCPEnvironment
from ci_triage_env.env.episode import EpisodeManager, EpisodeTerminatedError
from ci_triage_env.env.scenario_loader import load_scenarios
from ci_triage_env.env.tools import ALL_TOOL_HANDLERS, ToolHandler
from ci_triage_env.env.trace import trace_dir, write_trace
from ci_triage_env.env.wire import CITriageAction, CITriageObservation, CITriageState
from ci_triage_env.schemas.observation import Observation as DomainObservation
from ci_triage_env.schemas.scenario import Scenario, ToolOutput
from ci_triage_env.schemas.tools import ALL_TOOLS
logger = logging.getLogger(__name__)
def _build_mcp_server(handlers: list[ToolHandler]) -> FastMCP:
"""Build a FastMCP server advertising the 11 tools.
The function bodies are stubs in Phase A1: they validate args against the
schema and return placeholder payloads. Real routing into scenario-specific
tool outputs lands in A2. Cost / budget bookkeeping happens on the env's
``/step`` path via ``CITriageAction(kind="tool_call")``, not here, so that
one canonical path owns episode state.
"""
mcp = FastMCP(name="ci-triage")
handler_map = {h.name: h for h in handlers}
tool_defs = {t.name: t for t in ALL_TOOLS}
for tool_def in ALL_TOOLS:
handler = handler_map[tool_def.name]
def _make_fn(name: str, h: ToolHandler) -> Any:
def _fn(arguments: dict | None = None) -> dict:
args = arguments or {}
h.validate_args(args)
placeholder = ToolOutput(
tool_name=name,
payload={"stub": True, "tool": name},
cost_units=h.cost_unit,
)
return {"payload": placeholder.payload, "cost_units": placeholder.cost_units}
_fn.__name__ = name
return _fn
mcp.tool(
name=tool_def.name,
description=tool_def.description,
)(_make_fn(tool_def.name, handler))
# Sanity: ALL_TOOLS already excludes reserved names; the MCPEnvironment will
# re-validate in __init__ but assert here too for fast feedback.
reserved = {"reset", "step", "state", "close"}
assert not (set(tool_defs) & reserved)
return mcp
class CITriageEnv(MCPEnvironment):
"""CI triage env exposing 11 MCP tools and a structured terminal action."""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
def __init__(
self,
scenario_source: str | None = None,
scenarios: dict[str, Scenario] | None = None,
) -> None:
if scenarios is not None:
self._scenarios: dict[str, Scenario] = dict(scenarios)
else:
source = scenario_source or os.environ.get("CI_TRIAGE_SCENARIO_SOURCE")
self._scenarios = load_scenarios(source)
if not self._scenarios:
raise RuntimeError(
"no scenarios found; populate data_artifacts/scenarios/*.json or set "
"CI_TRIAGE_SCENARIO_SOURCE"
)
self._handlers: dict[str, ToolHandler] = {h.name: h for h in ALL_TOOL_HANDLERS}
super().__init__(mcp_server=_build_mcp_server(ALL_TOOL_HANDLERS))
self._episode: EpisodeManager | None = None
@property
def scenarios(self) -> dict[str, Scenario]:
return self._scenarios
def _seed_for(self, scenario: Scenario, episode_id: str, override: int | None) -> int:
if override is not None:
return override
return hash((scenario.seed, episode_id)) & 0xFFFFFFFF
def _wrap(self, domain_obs: DomainObservation) -> CITriageObservation:
return CITriageObservation(
done=domain_obs.is_terminal,
reward=None,
metadata={},
payload=domain_obs,
)
def reset(
self,
seed: int | None = None,
episode_id: str | None = None,
scenario_id: str | None = None,
**kwargs: Any,
) -> CITriageObservation:
if scenario_id is None:
scenario_id = next(iter(self._scenarios))
scenario = self._scenarios.get(scenario_id)
if scenario is None:
raise KeyError(f"unknown scenario_id: {scenario_id}")
if episode_id is None:
episode_id = str(uuid.uuid4())
self._episode = EpisodeManager(
scenario=scenario,
episode_id=episode_id,
seed=self._seed_for(scenario, episode_id, seed),
)
return self._wrap(self._episode.initial_observation())
def step(
self,
action: CITriageAction,
timeout_s: float | None = None,
**kwargs: Any,
) -> CITriageObservation:
# Override MCPEnvironment.step so /step always lands here for our typed
# action. ListToolsAction / CallToolAction still flow through MCP via
# the /mcp endpoint; clients hitting /step always send CITriageAction.
return self._step_impl(action, timeout_s=timeout_s, **kwargs)
def _step_impl(
self,
action: CITriageAction,
timeout_s: float | None = None,
**kwargs: Any,
) -> CITriageObservation:
if self._episode is None:
raise RuntimeError("env has no active episode; call reset first")
if self._episode.is_terminated:
raise EpisodeTerminatedError("episode already terminated")
if action.kind == "tool_call":
if action.tool_call is None:
raise ValueError("kind='tool_call' requires a tool_call payload")
handler = self._handlers.get(action.tool_call.tool_name)
if handler is None:
raise KeyError(f"unknown tool: {action.tool_call.tool_name}")
domain_obs = self._episode.apply_tool_call(action.tool_call, handler)
elif action.kind == "submit_diagnosis":
if action.terminal is None:
raise ValueError("kind='submit_diagnosis' requires a terminal payload")
domain_obs = self._episode.apply_terminal(action.terminal)
else:
raise ValueError(f"unknown action kind: {action.kind}")
if domain_obs.is_terminal:
try:
write_trace(self._episode, trace_dir())
except OSError:
# Trace write is best-effort: never fail an episode because the
# trace dir is read-only or full. The caller can still inspect
# state via /state.
pass
return self._wrap(domain_obs)
@property
def state(self) -> CITriageState:
if self._episode is None:
return CITriageState(
episode_id=None,
step_count=0,
payload=None,
)
domain_state = self._episode.to_state()
return CITriageState(
episode_id=domain_state.episode_id,
step_count=domain_state.step,
payload=domain_state,
)
def close(self) -> None:
# Inherit MCPEnvironment.close behavior (clears mcp client/server). We
# also drop the episode reference so a stale env can't be re-used.
self._episode = None
super().close()
# ---------------------------------------------------------------------------
# Module-level FastAPI app for `python -m ci_triage_env.env.server` / uvicorn.
# Use ``build_app()`` from tests to inject scenarios.
# ---------------------------------------------------------------------------
def build_app(env_factory=None, mount_visualizer: bool = True):
"""Build the OpenEnv-canonical FastAPI app.
Args:
env_factory: Callable returning a fresh ``CITriageEnv`` per HTTP request
(OpenEnv's stateless HTTP semantics) and per WebSocket session.
Defaults to ``CITriageEnv`` (loads scenarios from
``CI_TRIAGE_SCENARIO_SOURCE`` or ``data_artifacts/scenarios``).
mount_visualizer: If True (default), mount the static replay viewer
under ``/viz``. Set False for headless deployments.
"""
from fastapi import HTTPException
from fastapi.responses import JSONResponse
factory = env_factory or CITriageEnv
app = openenv_create_app(
env=factory,
action_cls=CITriageAction,
observation_cls=CITriageObservation,
env_name="ci-triage",
max_concurrent_envs=4,
)
@app.get("/trace/{episode_id}")
async def get_trace(episode_id: str) -> JSONResponse:
"""Return the serialised EpisodeTrace written by the env after termination."""
td = trace_dir()
trace_path = td / f"{episode_id}.json"
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"trace not found: {episode_id}")
return JSONResponse(content=__import__("json").loads(trace_path.read_text()))
if mount_visualizer:
from ci_triage_env.visualizer.server import build_visualizer_app
app.mount("/viz", build_visualizer_app(), name="viz")
return app
if __name__ == "__main__":
import uvicorn
uvicorn.run(build_app(), host="0.0.0.0", port=8000)
|