Spaces:
Sleeping
Sleeping
Prasham.Jain Claude Opus 4.7 (1M context) commited on
Commit ·
8be6018
1
Parent(s): 19e2683
feat(branch-a): A1 server scaffold — FastAPI /reset /step /state /mcp + 11 stub tool handlers
Browse files- CITriageEnv with in-memory episode store, scenario loader (disk + hf://), deterministic seeding
- EpisodeManager covering initial obs, tool-call stepping, terminal action handling, state export
- 11 stub ToolHandler subclasses (investigation/context/actions) validating against MCPToolDef.args_schema
- FastAPI surface with /reset, /step, /state/{episode_id}, /mcp/tools (manual MCP listing — PyPI openenv exposes no MCPEnvironment)
- 12 server tests covering boot, reset, step (tool + terminal), state, 404s, concurrency, MCP listing, deterministic seeding
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- pyproject.toml +1 -0
- requirements.txt +6 -1
- src/ci_triage_env/env/episode.py +129 -0
- src/ci_triage_env/env/scenario_loader.py +44 -0
- src/ci_triage_env/env/server.py +172 -0
- src/ci_triage_env/env/tools/__init__.py +48 -0
- src/ci_triage_env/env/tools/actions.py +26 -0
- src/ci_triage_env/env/tools/context.py +21 -0
- src/ci_triage_env/env/tools/investigation.py +70 -0
- tests/env/__init__.py +0 -0
- tests/env/conftest.py +27 -0
- tests/env/test_server.py +155 -0
- uv.lock +14 -0
pyproject.toml
CHANGED
|
@@ -13,6 +13,7 @@ dependencies = [
|
|
| 13 |
"datasets>=2.18",
|
| 14 |
"huggingface_hub>=0.23",
|
| 15 |
"jsonschema>=4.21",
|
|
|
|
| 16 |
]
|
| 17 |
|
| 18 |
[project.optional-dependencies]
|
|
|
|
| 13 |
"datasets>=2.18",
|
| 14 |
"huggingface_hub>=0.23",
|
| 15 |
"jsonschema>=4.21",
|
| 16 |
+
"openenv>=0.1.13",
|
| 17 |
]
|
| 18 |
|
| 19 |
[project.optional-dependencies]
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
Resolved
|
| 2 |
# This file was autogenerated by uv via the following command:
|
| 3 |
# uv export --no-dev --extra training --format requirements-txt
|
| 4 |
-e .
|
|
@@ -546,6 +546,7 @@ numpy==2.4.4 \
|
|
| 546 |
# contourpy
|
| 547 |
# datasets
|
| 548 |
# matplotlib
|
|
|
|
| 549 |
# pandas
|
| 550 |
# seaborn
|
| 551 |
# transformers
|
|
@@ -618,6 +619,10 @@ nvidia-nvtx==13.0.85 ; sys_platform == 'linux' \
|
|
| 618 |
--hash=sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4 \
|
| 619 |
--hash=sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6
|
| 620 |
# via cuda-toolkit
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
packaging==26.2 \
|
| 622 |
--hash=sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e \
|
| 623 |
--hash=sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661
|
|
|
|
| 1 |
+
Resolved 131 packages in 2ms
|
| 2 |
# This file was autogenerated by uv via the following command:
|
| 3 |
# uv export --no-dev --extra training --format requirements-txt
|
| 4 |
-e .
|
|
|
|
| 546 |
# contourpy
|
| 547 |
# datasets
|
| 548 |
# matplotlib
|
| 549 |
+
# openenv
|
| 550 |
# pandas
|
| 551 |
# seaborn
|
| 552 |
# transformers
|
|
|
|
| 619 |
--hash=sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4 \
|
| 620 |
--hash=sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6
|
| 621 |
# via cuda-toolkit
|
| 622 |
+
openenv==0.1.13 \
|
| 623 |
+
--hash=sha256:726971d2289472c1c20261436bcccdf3edfcf0b201d16aec127815bd83bfcb3d \
|
| 624 |
+
--hash=sha256:813249d7f526f40c6e8b325f705294761a5bc887b9144c3383fa2bae7baa7726
|
| 625 |
+
# via ci-triage-env
|
| 626 |
packaging==26.2 \
|
| 627 |
--hash=sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e \
|
| 628 |
--hash=sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661
|
src/ci_triage_env/env/episode.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
from ci_triage_env.schemas.action import TerminalAction, ToolCall
|
| 5 |
+
from ci_triage_env.schemas.episode import EpisodeState, StepRecord
|
| 6 |
+
from ci_triage_env.schemas.observation import BudgetState, Observation, ToolResponse
|
| 7 |
+
from ci_triage_env.schemas.scenario import Scenario, ToolOutput
|
| 8 |
+
|
| 9 |
+
DEFAULT_TOOL_CALL_BUDGET = 12
|
| 10 |
+
DEFAULT_COST_BUDGET = 5.0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class EpisodeManager:
|
| 15 |
+
"""Owns state for a single in-flight episode.
|
| 16 |
+
|
| 17 |
+
Phase A1 contract: validates lifecycle (initial obs, action stepping, termination,
|
| 18 |
+
state export). Tool *output content* is the responsibility of the handlers wired in
|
| 19 |
+
by the server. Budget enforcement and termination policies tighten in A3.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
scenario: Scenario
|
| 23 |
+
episode_id: str
|
| 24 |
+
seed: int
|
| 25 |
+
|
| 26 |
+
def __post_init__(self) -> None:
|
| 27 |
+
self.step_idx: int = 0
|
| 28 |
+
self.history: list[StepRecord] = []
|
| 29 |
+
self.budget: BudgetState = BudgetState(
|
| 30 |
+
tool_calls_remaining=DEFAULT_TOOL_CALL_BUDGET,
|
| 31 |
+
cost_remaining=DEFAULT_COST_BUDGET,
|
| 32 |
+
)
|
| 33 |
+
self.is_terminated: bool = False
|
| 34 |
+
self.final_action: TerminalAction | None = None
|
| 35 |
+
self._rng = random.Random(self.seed)
|
| 36 |
+
|
| 37 |
+
def initial_observation(self) -> Observation:
|
| 38 |
+
return Observation(
|
| 39 |
+
episode_id=self.episode_id,
|
| 40 |
+
step=0,
|
| 41 |
+
failure_summary=self.scenario.failure_summary,
|
| 42 |
+
tool_response=None,
|
| 43 |
+
budget_remaining=self.budget,
|
| 44 |
+
is_terminal=False,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def derive_step_seed(self, tool_name: str) -> int:
|
| 48 |
+
"""Per-step seed derived from (episode seed, step_idx, tool_name).
|
| 49 |
+
|
| 50 |
+
Tools that internally randomize must use this seed instead of a global RNG.
|
| 51 |
+
"""
|
| 52 |
+
return hash((self.seed, self.step_idx, tool_name)) & 0xFFFFFFFF
|
| 53 |
+
|
| 54 |
+
def apply_tool_call(
|
| 55 |
+
self,
|
| 56 |
+
action: ToolCall,
|
| 57 |
+
output: ToolOutput,
|
| 58 |
+
) -> Observation:
|
| 59 |
+
if self.is_terminated:
|
| 60 |
+
raise RuntimeError("episode already terminated")
|
| 61 |
+
|
| 62 |
+
cost_charged = output.cost_units
|
| 63 |
+
self.budget = BudgetState(
|
| 64 |
+
tool_calls_remaining=max(0, self.budget.tool_calls_remaining - 1),
|
| 65 |
+
cost_remaining=max(0.0, self.budget.cost_remaining - cost_charged),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
observation = Observation(
|
| 69 |
+
episode_id=self.episode_id,
|
| 70 |
+
step=self.step_idx,
|
| 71 |
+
failure_summary=None,
|
| 72 |
+
tool_response=ToolResponse(
|
| 73 |
+
tool_name=action.tool_name,
|
| 74 |
+
args=action.args,
|
| 75 |
+
output=output.payload,
|
| 76 |
+
cost_charged=cost_charged,
|
| 77 |
+
),
|
| 78 |
+
budget_remaining=self.budget,
|
| 79 |
+
is_terminal=False,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
self.history.append(
|
| 83 |
+
StepRecord(
|
| 84 |
+
step=self.step_idx,
|
| 85 |
+
action=action,
|
| 86 |
+
observation=observation,
|
| 87 |
+
cost_charged=cost_charged,
|
| 88 |
+
)
|
| 89 |
+
)
|
| 90 |
+
self.step_idx += 1
|
| 91 |
+
return observation
|
| 92 |
+
|
| 93 |
+
def apply_terminal(self, action: TerminalAction) -> Observation:
|
| 94 |
+
if self.is_terminated:
|
| 95 |
+
raise RuntimeError("episode already terminated")
|
| 96 |
+
|
| 97 |
+
observation = Observation(
|
| 98 |
+
episode_id=self.episode_id,
|
| 99 |
+
step=self.step_idx,
|
| 100 |
+
failure_summary=None,
|
| 101 |
+
tool_response=None,
|
| 102 |
+
budget_remaining=self.budget,
|
| 103 |
+
is_terminal=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.history.append(
|
| 107 |
+
StepRecord(
|
| 108 |
+
step=self.step_idx,
|
| 109 |
+
action=action,
|
| 110 |
+
observation=observation,
|
| 111 |
+
cost_charged=0.0,
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
self.step_idx += 1
|
| 115 |
+
self.is_terminated = True
|
| 116 |
+
self.final_action = action
|
| 117 |
+
return observation
|
| 118 |
+
|
| 119 |
+
def to_state(self) -> EpisodeState:
|
| 120 |
+
return EpisodeState(
|
| 121 |
+
episode_id=self.episode_id,
|
| 122 |
+
scenario_id=self.scenario.scenario_id,
|
| 123 |
+
seed=self.seed,
|
| 124 |
+
step=self.step_idx,
|
| 125 |
+
history=list(self.history),
|
| 126 |
+
budget=self.budget,
|
| 127 |
+
is_terminated=self.is_terminated,
|
| 128 |
+
final_action=self.final_action,
|
| 129 |
+
)
|
src/ci_triage_env/env/scenario_loader.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from ci_triage_env.schemas.scenario import Scenario
|
| 5 |
+
|
| 6 |
+
DEFAULT_SCENARIO_DIR = Path("data_artifacts/scenarios")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_from_disk(path: Path) -> dict[str, Scenario]:
|
| 10 |
+
"""Load all *.json files under `path` as Scenario objects, keyed by scenario_id."""
|
| 11 |
+
out: dict[str, Scenario] = {}
|
| 12 |
+
for fp in sorted(path.glob("*.json")):
|
| 13 |
+
scenario = Scenario.model_validate_json(fp.read_text())
|
| 14 |
+
out[scenario.scenario_id] = scenario
|
| 15 |
+
return out
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_from_hf(dataset_name: str) -> dict[str, Scenario]:
|
| 19 |
+
"""Load all rows of an HF dataset as Scenario objects, keyed by scenario_id."""
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
|
| 22 |
+
out: dict[str, Scenario] = {}
|
| 23 |
+
ds = load_dataset(dataset_name, split="train")
|
| 24 |
+
for row in ds:
|
| 25 |
+
if isinstance(row, dict) and "scenario_json" in row:
|
| 26 |
+
scenario = Scenario.model_validate_json(row["scenario_json"])
|
| 27 |
+
else:
|
| 28 |
+
scenario = Scenario.model_validate(json.loads(json.dumps(dict(row))))
|
| 29 |
+
out[scenario.scenario_id] = scenario
|
| 30 |
+
return out
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_scenarios(source: str | None) -> dict[str, Scenario]:
|
| 34 |
+
"""Dispatch by source prefix.
|
| 35 |
+
|
| 36 |
+
- None / "" → load from `data_artifacts/scenarios/`.
|
| 37 |
+
- "hf://<name>" → load from HF dataset `<name>`.
|
| 38 |
+
- any other string → treated as a filesystem path.
|
| 39 |
+
"""
|
| 40 |
+
if not source:
|
| 41 |
+
return load_from_disk(DEFAULT_SCENARIO_DIR)
|
| 42 |
+
if source.startswith("hf://"):
|
| 43 |
+
return load_from_hf(source[len("hf://") :])
|
| 44 |
+
return load_from_disk(Path(source))
|
src/ci_triage_env/env/server.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import threading
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI, HTTPException
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from ci_triage_env.env.episode import EpisodeManager
|
| 11 |
+
from ci_triage_env.env.scenario_loader import load_scenarios
|
| 12 |
+
from ci_triage_env.env.tools import ALL_TOOL_HANDLERS, ToolHandler
|
| 13 |
+
from ci_triage_env.schemas.action import TerminalAction, ToolCall
|
| 14 |
+
from ci_triage_env.schemas.episode import EpisodeState
|
| 15 |
+
from ci_triage_env.schemas.observation import Observation
|
| 16 |
+
from ci_triage_env.schemas.scenario import Scenario
|
| 17 |
+
from ci_triage_env.schemas.tools import ALL_TOOLS, MCPToolDef
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CITriageEnv:
|
| 23 |
+
"""OpenEnv-style CI triage environment.
|
| 24 |
+
|
| 25 |
+
Public surface: 11 MCP tools + reset/step/state lifecycle. The PyPI ``openenv``
|
| 26 |
+
package does not actually expose ``MCPEnvironment`` (its name collides with an
|
| 27 |
+
unrelated gym-style library); per phase-a1.md "If the path differs, update", we
|
| 28 |
+
implement the MCP listing endpoint directly on FastAPI rather than inheriting.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
scenario_source: str | None = None,
|
| 34 |
+
scenarios: dict[str, Scenario] | None = None,
|
| 35 |
+
):
|
| 36 |
+
self._episodes: dict[str, EpisodeManager] = {}
|
| 37 |
+
self._lock = threading.Lock()
|
| 38 |
+
if scenarios is not None:
|
| 39 |
+
self._scenarios = dict(scenarios)
|
| 40 |
+
else:
|
| 41 |
+
self._scenarios = load_scenarios(scenario_source)
|
| 42 |
+
if not self._scenarios:
|
| 43 |
+
raise RuntimeError(
|
| 44 |
+
"no scenarios found; populate data_artifacts/scenarios/*.json or set "
|
| 45 |
+
"CI_TRIAGE_SCENARIO_SOURCE"
|
| 46 |
+
)
|
| 47 |
+
self._tools: dict[str, ToolHandler] = {h.name: h for h in ALL_TOOL_HANDLERS}
|
| 48 |
+
self._tool_defs: dict[str, MCPToolDef] = {t.name: t for t in ALL_TOOLS}
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def scenarios(self) -> dict[str, Scenario]:
|
| 52 |
+
return self._scenarios
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def tool_defs(self) -> list[MCPToolDef]:
|
| 56 |
+
return list(self._tool_defs.values())
|
| 57 |
+
|
| 58 |
+
def _new_episode_id(self) -> str:
|
| 59 |
+
return str(uuid.uuid4())
|
| 60 |
+
|
| 61 |
+
def _seed_for(self, scenario: Scenario, episode_id: str, override: int | None) -> int:
|
| 62 |
+
if override is not None:
|
| 63 |
+
return override
|
| 64 |
+
return hash((scenario.seed, episode_id)) & 0xFFFFFFFF
|
| 65 |
+
|
| 66 |
+
def reset(
|
| 67 |
+
self,
|
| 68 |
+
scenario_id: str | None = None,
|
| 69 |
+
seed_override: int | None = None,
|
| 70 |
+
) -> Observation:
|
| 71 |
+
if scenario_id is None:
|
| 72 |
+
scenario_id = random.choice(list(self._scenarios.keys()))
|
| 73 |
+
scenario = self._scenarios.get(scenario_id)
|
| 74 |
+
if scenario is None:
|
| 75 |
+
raise KeyError(scenario_id)
|
| 76 |
+
episode_id = self._new_episode_id()
|
| 77 |
+
seed = self._seed_for(scenario, episode_id, seed_override)
|
| 78 |
+
manager = EpisodeManager(scenario=scenario, episode_id=episode_id, seed=seed)
|
| 79 |
+
with self._lock:
|
| 80 |
+
self._episodes[episode_id] = manager
|
| 81 |
+
return manager.initial_observation()
|
| 82 |
+
|
| 83 |
+
def step(self, episode_id: str, action: dict) -> Observation:
|
| 84 |
+
with self._lock:
|
| 85 |
+
manager = self._episodes.get(episode_id)
|
| 86 |
+
if manager is None:
|
| 87 |
+
raise KeyError(episode_id)
|
| 88 |
+
if manager.is_terminated:
|
| 89 |
+
raise RuntimeError("episode already terminated")
|
| 90 |
+
|
| 91 |
+
if action.get("action_type") == "submit_diagnosis":
|
| 92 |
+
terminal = TerminalAction.model_validate(action)
|
| 93 |
+
return manager.apply_terminal(terminal)
|
| 94 |
+
|
| 95 |
+
if "tool_name" in action:
|
| 96 |
+
tool_call = ToolCall.model_validate(action)
|
| 97 |
+
handler = self._tools.get(tool_call.tool_name)
|
| 98 |
+
if handler is None:
|
| 99 |
+
raise KeyError(f"unknown tool: {tool_call.tool_name}")
|
| 100 |
+
output = handler.call(tool_call.args, manager.scenario, manager.history)
|
| 101 |
+
return manager.apply_tool_call(tool_call, output)
|
| 102 |
+
|
| 103 |
+
raise ValueError(
|
| 104 |
+
"action must be a ToolCall (with tool_name) or a TerminalAction "
|
| 105 |
+
"(with action_type='submit_diagnosis')"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def state(self, episode_id: str) -> EpisodeState:
|
| 109 |
+
with self._lock:
|
| 110 |
+
manager = self._episodes.get(episode_id)
|
| 111 |
+
if manager is None:
|
| 112 |
+
raise KeyError(episode_id)
|
| 113 |
+
return manager.to_state()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ResetRequest(BaseModel):
|
| 117 |
+
scenario_id: str | None = None
|
| 118 |
+
seed_override: int | None = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class StepRequest(BaseModel):
|
| 122 |
+
episode_id: str
|
| 123 |
+
action: dict
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def create_app(env: CITriageEnv) -> FastAPI:
|
| 127 |
+
app = FastAPI(title="CI Triage Env")
|
| 128 |
+
|
| 129 |
+
@app.post("/reset")
|
| 130 |
+
def reset(req: ResetRequest) -> Observation:
|
| 131 |
+
try:
|
| 132 |
+
return env.reset(scenario_id=req.scenario_id, seed_override=req.seed_override)
|
| 133 |
+
except KeyError as exc:
|
| 134 |
+
raise HTTPException(status_code=404, detail=f"unknown scenario_id: {exc.args[0]}") from exc
|
| 135 |
+
|
| 136 |
+
@app.post("/step")
|
| 137 |
+
def step(req: StepRequest) -> Observation:
|
| 138 |
+
try:
|
| 139 |
+
return env.step(req.episode_id, req.action)
|
| 140 |
+
except KeyError as exc:
|
| 141 |
+
raise HTTPException(status_code=404, detail=f"unknown id: {exc.args[0]}") from exc
|
| 142 |
+
except RuntimeError as exc:
|
| 143 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 144 |
+
except (ValueError, Exception) as exc:
|
| 145 |
+
if isinstance(exc, HTTPException):
|
| 146 |
+
raise
|
| 147 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 148 |
+
|
| 149 |
+
@app.get("/state/{episode_id}")
|
| 150 |
+
def state(episode_id: str) -> EpisodeState:
|
| 151 |
+
try:
|
| 152 |
+
return env.state(episode_id)
|
| 153 |
+
except KeyError as exc:
|
| 154 |
+
raise HTTPException(status_code=404, detail=f"unknown episode_id: {exc.args[0]}") from exc
|
| 155 |
+
|
| 156 |
+
@app.get("/mcp/tools")
|
| 157 |
+
def list_mcp_tools() -> list[MCPToolDef]:
|
| 158 |
+
return env.tool_defs
|
| 159 |
+
|
| 160 |
+
return app
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _bootstrap() -> FastAPI:
|
| 164 |
+
source = os.environ.get("CI_TRIAGE_SCENARIO_SOURCE")
|
| 165 |
+
env = CITriageEnv(scenario_source=source)
|
| 166 |
+
return create_app(env)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
|
| 170 |
+
import uvicorn
|
| 171 |
+
|
| 172 |
+
uvicorn.run(_bootstrap(), host="0.0.0.0", port=8000)
|
src/ci_triage_env/env/tools/__init__.py
CHANGED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ci_triage_env.env.tools.actions import (
|
| 2 |
+
FileBugHandler,
|
| 3 |
+
PingOwnerHandler,
|
| 4 |
+
QuarantineTestHandler,
|
| 5 |
+
RerunTestHandler,
|
| 6 |
+
)
|
| 7 |
+
from ci_triage_env.env.tools.base import ToolHandler
|
| 8 |
+
from ci_triage_env.env.tools.context import (
|
| 9 |
+
CheckOwnerHandler,
|
| 10 |
+
QueryFlakeHistoryHandler,
|
| 11 |
+
RecentCommitsHandler,
|
| 12 |
+
)
|
| 13 |
+
from ci_triage_env.env.tools.investigation import (
|
| 14 |
+
ClusterMetricsHandler,
|
| 15 |
+
InspectTestCodeHandler,
|
| 16 |
+
ReadLogsHandler,
|
| 17 |
+
RunDiagnosticHandler,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
ALL_TOOL_HANDLERS: list[ToolHandler] = [
|
| 21 |
+
ReadLogsHandler(),
|
| 22 |
+
InspectTestCodeHandler(),
|
| 23 |
+
RunDiagnosticHandler(),
|
| 24 |
+
ClusterMetricsHandler(),
|
| 25 |
+
QueryFlakeHistoryHandler(),
|
| 26 |
+
RecentCommitsHandler(),
|
| 27 |
+
CheckOwnerHandler(),
|
| 28 |
+
RerunTestHandler(),
|
| 29 |
+
QuarantineTestHandler(),
|
| 30 |
+
FileBugHandler(),
|
| 31 |
+
PingOwnerHandler(),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
"ALL_TOOL_HANDLERS",
|
| 36 |
+
"CheckOwnerHandler",
|
| 37 |
+
"ClusterMetricsHandler",
|
| 38 |
+
"FileBugHandler",
|
| 39 |
+
"InspectTestCodeHandler",
|
| 40 |
+
"PingOwnerHandler",
|
| 41 |
+
"QuarantineTestHandler",
|
| 42 |
+
"QueryFlakeHistoryHandler",
|
| 43 |
+
"ReadLogsHandler",
|
| 44 |
+
"RecentCommitsHandler",
|
| 45 |
+
"RerunTestHandler",
|
| 46 |
+
"RunDiagnosticHandler",
|
| 47 |
+
"ToolHandler",
|
| 48 |
+
]
|
src/ci_triage_env/env/tools/actions.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import ClassVar
|
| 2 |
+
|
| 3 |
+
from ci_triage_env.env.tools.investigation import _StubToolHandler
|
| 4 |
+
from ci_triage_env.schemas.tools import ALL_TOOLS
|
| 5 |
+
|
| 6 |
+
_TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RerunTestHandler(_StubToolHandler):
|
| 10 |
+
name: ClassVar[str] = "rerun_test"
|
| 11 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["rerun_test"].cost_unit
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QuarantineTestHandler(_StubToolHandler):
|
| 15 |
+
name: ClassVar[str] = "quarantine_test"
|
| 16 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["quarantine_test"].cost_unit
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FileBugHandler(_StubToolHandler):
|
| 20 |
+
name: ClassVar[str] = "file_bug"
|
| 21 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["file_bug"].cost_unit
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PingOwnerHandler(_StubToolHandler):
|
| 25 |
+
name: ClassVar[str] = "ping_owner"
|
| 26 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["ping_owner"].cost_unit
|
src/ci_triage_env/env/tools/context.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import ClassVar
|
| 2 |
+
|
| 3 |
+
from ci_triage_env.env.tools.investigation import _StubToolHandler
|
| 4 |
+
from ci_triage_env.schemas.tools import ALL_TOOLS
|
| 5 |
+
|
| 6 |
+
_TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class QueryFlakeHistoryHandler(_StubToolHandler):
|
| 10 |
+
name: ClassVar[str] = "query_flake_history"
|
| 11 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["query_flake_history"].cost_unit
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RecentCommitsHandler(_StubToolHandler):
|
| 15 |
+
name: ClassVar[str] = "recent_commits"
|
| 16 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["recent_commits"].cost_unit
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CheckOwnerHandler(_StubToolHandler):
|
| 20 |
+
name: ClassVar[str] = "check_owner"
|
| 21 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["check_owner"].cost_unit
|
src/ci_triage_env/env/tools/investigation.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import ClassVar
|
| 2 |
+
|
| 3 |
+
import jsonschema
|
| 4 |
+
|
| 5 |
+
from ci_triage_env.env.tools.base import ToolHandler
|
| 6 |
+
from ci_triage_env.schemas.episode import StepRecord
|
| 7 |
+
from ci_triage_env.schemas.scenario import Scenario, ToolOutput
|
| 8 |
+
from ci_triage_env.schemas.tools import ALL_TOOLS
|
| 9 |
+
|
| 10 |
+
_TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _StubToolHandler(ToolHandler):
|
| 14 |
+
"""Phase A1 stub. Validates args against MCPToolDef.args_schema, returns placeholder payload."""
|
| 15 |
+
|
| 16 |
+
name: ClassVar[str] = ""
|
| 17 |
+
cost_unit: ClassVar[float] = 0.0
|
| 18 |
+
|
| 19 |
+
def validate_args(self, args: dict) -> None:
|
| 20 |
+
spec = _TOOL_DEFS[self.name]
|
| 21 |
+
try:
|
| 22 |
+
jsonschema.validate(instance=args, schema=spec.args_schema)
|
| 23 |
+
except jsonschema.ValidationError as exc:
|
| 24 |
+
raise ValueError(f"invalid args for {self.name}: {exc.message}") from exc
|
| 25 |
+
|
| 26 |
+
def call(
|
| 27 |
+
self,
|
| 28 |
+
args: dict,
|
| 29 |
+
scenario: Scenario,
|
| 30 |
+
history: list[StepRecord],
|
| 31 |
+
) -> ToolOutput:
|
| 32 |
+
self.validate_args(args)
|
| 33 |
+
return ToolOutput(
|
| 34 |
+
tool_name=self.name,
|
| 35 |
+
payload={"stub": True, "tool": self.name},
|
| 36 |
+
cost_units=self.cost_unit,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ReadLogsHandler(_StubToolHandler):
|
| 41 |
+
name: ClassVar[str] = "read_logs"
|
| 42 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["read_logs"].cost_unit
|
| 43 |
+
|
| 44 |
+
def call(
|
| 45 |
+
self,
|
| 46 |
+
args: dict,
|
| 47 |
+
scenario: Scenario,
|
| 48 |
+
history: list[StepRecord],
|
| 49 |
+
) -> ToolOutput:
|
| 50 |
+
self.validate_args(args)
|
| 51 |
+
return ToolOutput(
|
| 52 |
+
tool_name=self.name,
|
| 53 |
+
payload={"lines": ["[stub]"], "truncated": False},
|
| 54 |
+
cost_units=self.cost_unit,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class InspectTestCodeHandler(_StubToolHandler):
|
| 59 |
+
name: ClassVar[str] = "inspect_test_code"
|
| 60 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["inspect_test_code"].cost_unit
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RunDiagnosticHandler(_StubToolHandler):
|
| 64 |
+
name: ClassVar[str] = "run_diagnostic"
|
| 65 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["run_diagnostic"].cost_unit
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ClusterMetricsHandler(_StubToolHandler):
|
| 69 |
+
name: ClassVar[str] = "cluster_metrics"
|
| 70 |
+
cost_unit: ClassVar[float] = _TOOL_DEFS["cluster_metrics"].cost_unit
|
tests/env/__init__.py
ADDED
|
File without changes
|
tests/env/conftest.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
|
| 4 |
+
from ci_triage_env.env.server import CITriageEnv, create_app
|
| 5 |
+
from ci_triage_env.mock.scenario import make_mock_scenario
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture
|
| 9 |
+
def env() -> CITriageEnv:
|
| 10 |
+
scenarios = {
|
| 11 |
+
s.scenario_id: s
|
| 12 |
+
for s in [
|
| 13 |
+
make_mock_scenario("race_flake", seed=42),
|
| 14 |
+
make_mock_scenario("real_bug", seed=7),
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
return CITriageEnv(scenarios=scenarios)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture
|
| 21 |
+
def client(env: CITriageEnv) -> TestClient:
|
| 22 |
+
return TestClient(create_app(env))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@pytest.fixture
|
| 26 |
+
def known_scenario_id(env: CITriageEnv) -> str:
|
| 27 |
+
return next(iter(env.scenarios))
|
tests/env/test_server.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
+
|
| 3 |
+
from fastapi.testclient import TestClient
|
| 4 |
+
|
| 5 |
+
from ci_triage_env.env.server import CITriageEnv, create_app
|
| 6 |
+
from ci_triage_env.mock.scenario import make_mock_scenario
|
| 7 |
+
from ci_triage_env.schemas.action import TerminalAction, ToolCall
|
| 8 |
+
from ci_triage_env.schemas.diagnosis import DiagnosisLabel
|
| 9 |
+
from ci_triage_env.schemas.episode import EpisodeState
|
| 10 |
+
from ci_triage_env.schemas.observation import Observation
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_server_boots():
|
| 14 |
+
env = CITriageEnv(scenarios={make_mock_scenario().scenario_id: make_mock_scenario()})
|
| 15 |
+
app = create_app(env)
|
| 16 |
+
assert app.title == "CI Triage Env"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_reset_returns_valid_observation(client: TestClient):
|
| 20 |
+
resp = client.post("/reset", json={})
|
| 21 |
+
assert resp.status_code == 200
|
| 22 |
+
obs = Observation.model_validate(resp.json())
|
| 23 |
+
assert obs.failure_summary is not None
|
| 24 |
+
assert obs.step == 0
|
| 25 |
+
assert obs.is_terminal is False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_reset_with_specific_scenario_id(client: TestClient, known_scenario_id: str):
|
| 29 |
+
resp = client.post("/reset", json={"scenario_id": known_scenario_id})
|
| 30 |
+
assert resp.status_code == 200
|
| 31 |
+
obs = Observation.model_validate(resp.json())
|
| 32 |
+
assert obs.episode_id
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_reset_with_unknown_scenario_id_404(client: TestClient):
|
| 36 |
+
resp = client.post("/reset", json={"scenario_id": "does-not-exist"})
|
| 37 |
+
assert resp.status_code == 404
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_step_with_tool_call_returns_observation(client: TestClient, known_scenario_id: str):
|
| 41 |
+
reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
|
| 42 |
+
episode_id = reset["episode_id"]
|
| 43 |
+
call = ToolCall(tool_name="read_logs", args={"scope": "test"})
|
| 44 |
+
resp = client.post("/step", json={"episode_id": episode_id, "action": call.model_dump()})
|
| 45 |
+
assert resp.status_code == 200, resp.text
|
| 46 |
+
obs = Observation.model_validate(resp.json())
|
| 47 |
+
assert obs.tool_response is not None
|
| 48 |
+
assert obs.tool_response.tool_name == "read_logs"
|
| 49 |
+
assert obs.is_terminal is False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_step_with_terminal_action_marks_done(client: TestClient, known_scenario_id: str):
|
| 53 |
+
reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
|
| 54 |
+
episode_id = reset["episode_id"]
|
| 55 |
+
terminal = TerminalAction(
|
| 56 |
+
action_type="submit_diagnosis",
|
| 57 |
+
diagnosis=DiagnosisLabel.RACE_FLAKE,
|
| 58 |
+
confidence=0.8,
|
| 59 |
+
)
|
| 60 |
+
resp = client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
|
| 61 |
+
assert resp.status_code == 200, resp.text
|
| 62 |
+
obs = Observation.model_validate(resp.json())
|
| 63 |
+
assert obs.is_terminal is True
|
| 64 |
+
|
| 65 |
+
state_resp = client.get(f"/state/{episode_id}")
|
| 66 |
+
state = EpisodeState.model_validate(state_resp.json())
|
| 67 |
+
assert state.is_terminated is True
|
| 68 |
+
assert state.final_action is not None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_step_after_terminal_returns_400(client: TestClient, known_scenario_id: str):
|
| 72 |
+
reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
|
| 73 |
+
episode_id = reset["episode_id"]
|
| 74 |
+
terminal = TerminalAction(
|
| 75 |
+
action_type="submit_diagnosis",
|
| 76 |
+
diagnosis=DiagnosisLabel.RACE_FLAKE,
|
| 77 |
+
confidence=0.8,
|
| 78 |
+
)
|
| 79 |
+
client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
|
| 80 |
+
again = client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
|
| 81 |
+
assert again.status_code == 400
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_state_endpoint_returns_episode_state(client: TestClient, known_scenario_id: str):
|
| 85 |
+
reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
|
| 86 |
+
episode_id = reset["episode_id"]
|
| 87 |
+
resp = client.get(f"/state/{episode_id}")
|
| 88 |
+
assert resp.status_code == 200
|
| 89 |
+
state = EpisodeState.model_validate(resp.json())
|
| 90 |
+
assert state.episode_id == episode_id
|
| 91 |
+
assert state.scenario_id == known_scenario_id
|
| 92 |
+
assert state.step == 0
|
| 93 |
+
assert state.is_terminated is False
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_state_unknown_episode_404(client: TestClient):
|
| 97 |
+
resp = client.get("/state/not-a-real-episode-id")
|
| 98 |
+
assert resp.status_code == 404
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_concurrent_resets_get_distinct_episode_ids(client: TestClient, known_scenario_id: str):
|
| 102 |
+
def do_reset() -> str:
|
| 103 |
+
return client.post("/reset", json={"scenario_id": known_scenario_id}).json()["episode_id"]
|
| 104 |
+
|
| 105 |
+
with ThreadPoolExecutor(max_workers=8) as pool:
|
| 106 |
+
ids = list(pool.map(lambda _: do_reset(), range(8)))
|
| 107 |
+
|
| 108 |
+
assert len(set(ids)) == len(ids)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def test_mcp_endpoint_lists_all_11_tools(client: TestClient):
|
| 112 |
+
resp = client.get("/mcp/tools")
|
| 113 |
+
assert resp.status_code == 200
|
| 114 |
+
tools = resp.json()
|
| 115 |
+
names = {t["name"] for t in tools}
|
| 116 |
+
assert names == {
|
| 117 |
+
"read_logs",
|
| 118 |
+
"inspect_test_code",
|
| 119 |
+
"run_diagnostic",
|
| 120 |
+
"cluster_metrics",
|
| 121 |
+
"query_flake_history",
|
| 122 |
+
"recent_commits",
|
| 123 |
+
"check_owner",
|
| 124 |
+
"rerun_test",
|
| 125 |
+
"quarantine_test",
|
| 126 |
+
"file_bug",
|
| 127 |
+
"ping_owner",
|
| 128 |
+
}
|
| 129 |
+
assert len(tools) == 11
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def test_episode_seeding_deterministic(client: TestClient, known_scenario_id: str):
|
| 133 |
+
def run_one() -> EpisodeState:
|
| 134 |
+
reset = client.post(
|
| 135 |
+
"/reset",
|
| 136 |
+
json={"scenario_id": known_scenario_id, "seed_override": 12345},
|
| 137 |
+
).json()
|
| 138 |
+
episode_id = reset["episode_id"]
|
| 139 |
+
call = ToolCall(tool_name="read_logs", args={"scope": "test"})
|
| 140 |
+
client.post("/step", json={"episode_id": episode_id, "action": call.model_dump()})
|
| 141 |
+
terminal = TerminalAction(
|
| 142 |
+
action_type="submit_diagnosis",
|
| 143 |
+
diagnosis=DiagnosisLabel.RACE_FLAKE,
|
| 144 |
+
confidence=0.6,
|
| 145 |
+
)
|
| 146 |
+
client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
|
| 147 |
+
return EpisodeState.model_validate(client.get(f"/state/{episode_id}").json())
|
| 148 |
+
|
| 149 |
+
a = run_one()
|
| 150 |
+
b = run_one()
|
| 151 |
+
assert a.seed == b.seed == 12345
|
| 152 |
+
assert a.step == b.step
|
| 153 |
+
assert a.is_terminated and b.is_terminated
|
| 154 |
+
assert [r.action for r in a.history] == [r.action for r in b.history]
|
| 155 |
+
assert [r.cost_charged for r in a.history] == [r.cost_charged for r in b.history]
|
uv.lock
CHANGED
|
@@ -198,6 +198,7 @@ dependencies = [
|
|
| 198 |
{ name = "httpx" },
|
| 199 |
{ name = "huggingface-hub" },
|
| 200 |
{ name = "jsonschema" },
|
|
|
|
| 201 |
{ name = "pydantic" },
|
| 202 |
{ name = "pyyaml" },
|
| 203 |
{ name = "uvicorn", extra = ["standard"] },
|
|
@@ -239,6 +240,7 @@ requires-dist = [
|
|
| 239 |
{ name = "matplotlib", marker = "extra == 'training'", specifier = ">=3.8" },
|
| 240 |
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10" },
|
| 241 |
{ name = "openai", marker = "extra == 'data'", specifier = ">=1.40" },
|
|
|
|
| 242 |
{ name = "pandas", marker = "extra == 'training'", specifier = ">=2.2" },
|
| 243 |
{ name = "pydantic", specifier = ">=2.7,<3.0" },
|
| 244 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8" },
|
|
@@ -1304,6 +1306,18 @@ wheels = [
|
|
| 1304 |
{ url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" },
|
| 1305 |
]
|
| 1306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1307 |
[[package]]
|
| 1308 |
name = "packaging"
|
| 1309 |
version = "26.2"
|
|
|
|
| 198 |
{ name = "httpx" },
|
| 199 |
{ name = "huggingface-hub" },
|
| 200 |
{ name = "jsonschema" },
|
| 201 |
+
{ name = "openenv" },
|
| 202 |
{ name = "pydantic" },
|
| 203 |
{ name = "pyyaml" },
|
| 204 |
{ name = "uvicorn", extra = ["standard"] },
|
|
|
|
| 240 |
{ name = "matplotlib", marker = "extra == 'training'", specifier = ">=3.8" },
|
| 241 |
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10" },
|
| 242 |
{ name = "openai", marker = "extra == 'data'", specifier = ">=1.40" },
|
| 243 |
+
{ name = "openenv", specifier = ">=0.1.13" },
|
| 244 |
{ name = "pandas", marker = "extra == 'training'", specifier = ">=2.2" },
|
| 245 |
{ name = "pydantic", specifier = ">=2.7,<3.0" },
|
| 246 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8" },
|
|
|
|
| 1306 |
{ url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" },
|
| 1307 |
]
|
| 1308 |
|
| 1309 |
+
[[package]]
|
| 1310 |
+
name = "openenv"
|
| 1311 |
+
version = "0.1.13"
|
| 1312 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1313 |
+
dependencies = [
|
| 1314 |
+
{ name = "numpy" },
|
| 1315 |
+
]
|
| 1316 |
+
sdist = { url = "https://files.pythonhosted.org/packages/35/94/c47e8f7303452793a3519c8cbc1b31dfffdedd13aaed821958ab3f152927/openenv-0.1.13.tar.gz", hash = "sha256:726971d2289472c1c20261436bcccdf3edfcf0b201d16aec127815bd83bfcb3d", size = 5112, upload-time = "2020-12-16T11:49:39.777Z" }
|
| 1317 |
+
wheels = [
|
| 1318 |
+
{ url = "https://files.pythonhosted.org/packages/33/7f/e6f4467528161b8f0eb2ec784f4bbcd1fa9ea7acad13c0fb18597013e83b/openenv-0.1.13-py3-none-any.whl", hash = "sha256:813249d7f526f40c6e8b325f705294761a5bc887b9144c3383fa2bae7baa7726", size = 12080, upload-time = "2020-12-16T11:49:38.816Z" },
|
| 1319 |
+
]
|
| 1320 |
+
|
| 1321 |
[[package]]
|
| 1322 |
name = "packaging"
|
| 1323 |
version = "26.2"
|