Prasham.Jain Claude Opus 4.7 (1M context) committed on
Commit
8be6018
·
1 Parent(s): 19e2683

feat(branch-a): A1 server scaffold — FastAPI /reset /step /state /mcp + 11 stub tool handlers

Browse files

- CITriageEnv with in-memory episode store, scenario loader (disk + hf://), deterministic seeding
- EpisodeManager covering initial obs, tool-call stepping, terminal action handling, state export
- 11 stub ToolHandler subclasses (investigation/context/actions) validating against MCPToolDef.args_schema
- FastAPI surface with /reset, /step, /state/{episode_id}, /mcp/tools (manual MCP listing — PyPI openenv exposes no MCPEnvironment)
- 12 server tests covering boot, reset, step (tool + terminal), state, 404s, concurrency, MCP listing, deterministic seeding

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

pyproject.toml CHANGED
@@ -13,6 +13,7 @@ dependencies = [
13
  "datasets>=2.18",
14
  "huggingface_hub>=0.23",
15
  "jsonschema>=4.21",
 
16
  ]
17
 
18
  [project.optional-dependencies]
 
13
  "datasets>=2.18",
14
  "huggingface_hub>=0.23",
15
  "jsonschema>=4.21",
16
+ "openenv>=0.1.13",
17
  ]
18
 
19
  [project.optional-dependencies]
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- Resolved 130 packages in 2ms
2
  # This file was autogenerated by uv via the following command:
3
  # uv export --no-dev --extra training --format requirements-txt
4
  -e .
@@ -546,6 +546,7 @@ numpy==2.4.4 \
546
  # contourpy
547
  # datasets
548
  # matplotlib
 
549
  # pandas
550
  # seaborn
551
  # transformers
@@ -618,6 +619,10 @@ nvidia-nvtx==13.0.85 ; sys_platform == 'linux' \
618
  --hash=sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4 \
619
  --hash=sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6
620
  # via cuda-toolkit
 
 
 
 
621
  packaging==26.2 \
622
  --hash=sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e \
623
  --hash=sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661
 
1
+ Resolved 131 packages in 2ms
2
  # This file was autogenerated by uv via the following command:
3
  # uv export --no-dev --extra training --format requirements-txt
4
  -e .
 
546
  # contourpy
547
  # datasets
548
  # matplotlib
549
+ # openenv
550
  # pandas
551
  # seaborn
552
  # transformers
 
619
  --hash=sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4 \
620
  --hash=sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6
621
  # via cuda-toolkit
622
+ openenv==0.1.13 \
623
+ --hash=sha256:726971d2289472c1c20261436bcccdf3edfcf0b201d16aec127815bd83bfcb3d \
624
+ --hash=sha256:813249d7f526f40c6e8b325f705294761a5bc887b9144c3383fa2bae7baa7726
625
+ # via ci-triage-env
626
  packaging==26.2 \
627
  --hash=sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e \
628
  --hash=sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661
src/ci_triage_env/env/episode.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+
4
+ from ci_triage_env.schemas.action import TerminalAction, ToolCall
5
+ from ci_triage_env.schemas.episode import EpisodeState, StepRecord
6
+ from ci_triage_env.schemas.observation import BudgetState, Observation, ToolResponse
7
+ from ci_triage_env.schemas.scenario import Scenario, ToolOutput
8
+
9
+ DEFAULT_TOOL_CALL_BUDGET = 12
10
+ DEFAULT_COST_BUDGET = 5.0
11
+
12
+
@dataclass
class EpisodeManager:
    """Owns state for a single in-flight episode.

    Phase A1 contract: validates lifecycle (initial obs, action stepping, termination,
    state export). Tool *output content* is the responsibility of the handlers wired in
    by the server. Budget enforcement and termination policies tighten in A3.

    The three dataclass fields are supplied by the caller; every other attribute
    (step counter, history, budget, termination flags, RNG) is created in
    ``__post_init__``.
    """

    scenario: Scenario
    episode_id: str
    seed: int

    def __post_init__(self) -> None:
        """Initialise the mutable per-episode bookkeeping."""
        self.step_idx: int = 0
        self.history: list[StepRecord] = []
        self.budget: BudgetState = BudgetState(
            tool_calls_remaining=DEFAULT_TOOL_CALL_BUDGET,
            cost_remaining=DEFAULT_COST_BUDGET,
        )
        self.is_terminated: bool = False
        self.final_action: TerminalAction | None = None
        # Episode-local RNG seeded from the episode seed; nothing in this class
        # draws from it yet.
        self._rng = random.Random(self.seed)

    def initial_observation(self) -> Observation:
        """Return the step-0 observation carrying the scenario's failure summary."""
        return Observation(
            episode_id=self.episode_id,
            step=0,
            failure_summary=self.scenario.failure_summary,
            tool_response=None,
            budget_remaining=self.budget,
            is_terminal=False,
        )

    def derive_step_seed(self, tool_name: str) -> int:
        """Per-step seed derived from (episode seed, step_idx, tool_name).

        Tools that internally randomize must use this seed instead of a global RNG.

        The value is taken from a SHA-256 digest rather than the builtin ``hash()``:
        ``hash()`` of a ``str`` is salted per interpreter process, so it would yield
        a different seed on every run and defeat reproducibility.

        Returns:
            An integer in ``[0, 2**32)`` that is stable across processes.
        """
        import hashlib

        material = f"{self.seed}:{self.step_idx}:{tool_name}".encode("utf-8")
        return int.from_bytes(hashlib.sha256(material).digest()[:4], "big")

    def apply_tool_call(
        self,
        action: ToolCall,
        output: ToolOutput,
    ) -> Observation:
        """Record a tool call and its output, charge the budget, and advance one step.

        Args:
            action: The tool call issued by the agent.
            output: The handler's result; ``output.cost_units`` is charged.

        Returns:
            A non-terminal observation wrapping the tool response.

        Raises:
            RuntimeError: If the episode has already been terminated.
        """
        if self.is_terminated:
            raise RuntimeError("episode already terminated")

        cost_charged = output.cost_units
        # Both counters are clamped at zero; an exhausted budget does not reject
        # the call here.
        self.budget = BudgetState(
            tool_calls_remaining=max(0, self.budget.tool_calls_remaining - 1),
            cost_remaining=max(0.0, self.budget.cost_remaining - cost_charged),
        )

        observation = Observation(
            episode_id=self.episode_id,
            step=self.step_idx,
            failure_summary=None,
            tool_response=ToolResponse(
                tool_name=action.tool_name,
                args=action.args,
                output=output.payload,
                cost_charged=cost_charged,
            ),
            budget_remaining=self.budget,
            is_terminal=False,
        )

        self.history.append(
            StepRecord(
                step=self.step_idx,
                action=action,
                observation=observation,
                cost_charged=cost_charged,
            )
        )
        self.step_idx += 1
        return observation

    def apply_terminal(self, action: TerminalAction) -> Observation:
        """Record the terminal action, mark the episode finished, and return the final observation.

        Terminal actions are recorded with a cost of ``0.0`` and leave the budget
        untouched.

        Raises:
            RuntimeError: If the episode has already been terminated.
        """
        if self.is_terminated:
            raise RuntimeError("episode already terminated")

        observation = Observation(
            episode_id=self.episode_id,
            step=self.step_idx,
            failure_summary=None,
            tool_response=None,
            budget_remaining=self.budget,
            is_terminal=True,
        )

        self.history.append(
            StepRecord(
                step=self.step_idx,
                action=action,
                observation=observation,
                cost_charged=0.0,
            )
        )
        self.step_idx += 1
        self.is_terminated = True
        self.final_action = action
        return observation

    def to_state(self) -> EpisodeState:
        """Export a snapshot of the episode; ``history`` is a shallow copy of the list."""
        return EpisodeState(
            episode_id=self.episode_id,
            scenario_id=self.scenario.scenario_id,
            seed=self.seed,
            step=self.step_idx,
            history=list(self.history),
            budget=self.budget,
            is_terminated=self.is_terminated,
            final_action=self.final_action,
        )
src/ci_triage_env/env/scenario_loader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ from ci_triage_env.schemas.scenario import Scenario
5
+
6
+ DEFAULT_SCENARIO_DIR = Path("data_artifacts/scenarios")
7
+
8
+
def load_from_disk(path: Path) -> dict[str, Scenario]:
    """Load all *.json files under `path` as Scenario objects, keyed by scenario_id.

    Files are visited in sorted order, so when two files declare the same
    ``scenario_id`` the later one wins. A directory with no ``*.json`` files
    yields an empty dict.

    Files are decoded explicitly as UTF-8 so the result does not depend on the
    platform's locale encoding.
    """
    out: dict[str, Scenario] = {}
    for fp in sorted(path.glob("*.json")):
        scenario = Scenario.model_validate_json(fp.read_text(encoding="utf-8"))
        out[scenario.scenario_id] = scenario
    return out
16
+
17
+
def load_from_hf(dataset_name: str) -> dict[str, Scenario]:
    """Load every row of the ``train`` split of an HF dataset as a Scenario.

    A dict row that carries a ``scenario_json`` key is parsed from that JSON
    string. Any other row is converted to a dict, round-tripped through JSON,
    and validated as a Scenario. Results are keyed by ``scenario_id``.
    """
    # Imported lazily so that importing this module does not import `datasets`.
    from datasets import load_dataset

    def _parse(row) -> Scenario:
        if isinstance(row, dict) and "scenario_json" in row:
            return Scenario.model_validate_json(row["scenario_json"])
        return Scenario.model_validate(json.loads(json.dumps(dict(row))))

    scenarios: dict[str, Scenario] = {}
    for row in load_dataset(dataset_name, split="train"):
        parsed = _parse(row)
        scenarios[parsed.scenario_id] = parsed
    return scenarios
31
+
32
+
def load_scenarios(source: str | None) -> dict[str, Scenario]:
    """Pick a scenario loader based on the form of ``source``.

    - None / "" → load from `data_artifacts/scenarios/`.
    - "hf://<name>" → load from HF dataset `<name>`.
    - any other string → treated as a filesystem path.
    """
    hf_scheme = "hf://"
    if source and source.startswith(hf_scheme):
        return load_from_hf(source.removeprefix(hf_scheme))
    directory = Path(source) if source else DEFAULT_SCENARIO_DIR
    return load_from_disk(directory)
src/ci_triage_env/env/server.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import threading
5
+ import uuid
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+
10
+ from ci_triage_env.env.episode import EpisodeManager
11
+ from ci_triage_env.env.scenario_loader import load_scenarios
12
+ from ci_triage_env.env.tools import ALL_TOOL_HANDLERS, ToolHandler
13
+ from ci_triage_env.schemas.action import TerminalAction, ToolCall
14
+ from ci_triage_env.schemas.episode import EpisodeState
15
+ from ci_triage_env.schemas.observation import Observation
16
+ from ci_triage_env.schemas.scenario import Scenario
17
+ from ci_triage_env.schemas.tools import ALL_TOOLS, MCPToolDef
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
class CITriageEnv:
    """OpenEnv-style CI triage environment.

    Public surface: 11 MCP tools + reset/step/state lifecycle. The PyPI ``openenv``
    package does not actually expose ``MCPEnvironment`` (its name collides with an
    unrelated gym-style library); per phase-a1.md "If the path differs, update", we
    implement the MCP listing endpoint directly on FastAPI rather than inheriting.

    Episodes live in an in-memory dict guarded by ``_lock``. Each episode also has
    its own lock so that concurrent ``step``/``state`` calls on the *same* episode
    are serialised without blocking other episodes. Nothing in this class removes
    episodes from the store.
    """

    def __init__(
        self,
        scenario_source: str | None = None,
        scenarios: dict[str, Scenario] | None = None,
    ):
        """Create the environment.

        Args:
            scenario_source: Passed to ``load_scenarios`` when ``scenarios`` is None.
            scenarios: Pre-built scenarios keyed by id; copied, and takes
                precedence over ``scenario_source``.

        Raises:
            RuntimeError: If no scenarios are available.
        """
        self._episodes: dict[str, EpisodeManager] = {}
        self._episode_locks: dict[str, threading.Lock] = {}
        self._lock = threading.Lock()
        if scenarios is not None:
            self._scenarios = dict(scenarios)
        else:
            self._scenarios = load_scenarios(scenario_source)
        if not self._scenarios:
            raise RuntimeError(
                "no scenarios found; populate data_artifacts/scenarios/*.json or set "
                "CI_TRIAGE_SCENARIO_SOURCE"
            )
        self._tools: dict[str, ToolHandler] = {h.name: h for h in ALL_TOOL_HANDLERS}
        self._tool_defs: dict[str, MCPToolDef] = {t.name: t for t in ALL_TOOLS}

    @property
    def scenarios(self) -> dict[str, Scenario]:
        """The loaded scenarios keyed by scenario id (the internal dict, not a copy)."""
        return self._scenarios

    @property
    def tool_defs(self) -> list[MCPToolDef]:
        """A fresh list of the registered MCP tool definitions."""
        return list(self._tool_defs.values())

    def _new_episode_id(self) -> str:
        """Return a random UUID4 string used as the episode id."""
        return str(uuid.uuid4())

    def _seed_for(self, scenario: Scenario, episode_id: str, override: int | None) -> int:
        """Return the episode seed.

        ``override`` wins when it is not None (including ``0``). Otherwise the seed is
        derived from ``(scenario.seed, episode_id)`` with SHA-256 so that the same
        pair maps to the same 32-bit seed in every process; the builtin ``hash()``
        is salted per process for strings and would not.
        """
        if override is not None:
            return override
        import hashlib

        material = f"{scenario.seed}:{episode_id}".encode("utf-8")
        return int.from_bytes(hashlib.sha256(material).digest()[:4], "big")

    def reset(
        self,
        scenario_id: str | None = None,
        seed_override: int | None = None,
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            scenario_id: Scenario to run; a random one is chosen when None.
            seed_override: Explicit episode seed; derived when None.

        Raises:
            KeyError: If ``scenario_id`` is not a known scenario.
        """
        if scenario_id is None:
            scenario_id = random.choice(list(self._scenarios))
        scenario = self._scenarios.get(scenario_id)
        if scenario is None:
            raise KeyError(scenario_id)
        episode_id = self._new_episode_id()
        seed = self._seed_for(scenario, episode_id, seed_override)
        manager = EpisodeManager(scenario=scenario, episode_id=episode_id, seed=seed)
        with self._lock:
            self._episodes[episode_id] = manager
            self._episode_locks[episode_id] = threading.Lock()
        return manager.initial_observation()

    def _lookup(self, episode_id: str) -> tuple[EpisodeManager, threading.Lock]:
        """Return the manager and per-episode lock for ``episode_id`` or raise KeyError."""
        with self._lock:
            manager = self._episodes.get(episode_id)
            if manager is None:
                raise KeyError(episode_id)
            # setdefault covers episodes that were registered without a lock.
            episode_lock = self._episode_locks.setdefault(episode_id, threading.Lock())
        return manager, episode_lock

    def step(self, episode_id: str, action: dict) -> Observation:
        """Apply one action to an episode.

        ``action`` is treated as a terminal action when its ``action_type`` is
        ``"submit_diagnosis"``, and as a tool call when it has a ``tool_name`` key.

        The termination check, the handler call and the state mutation all happen
        under the episode's own lock, so two concurrent steps on the same episode
        cannot interleave.

        Raises:
            KeyError: Unknown episode id, or unknown tool name.
            RuntimeError: The episode is already terminated.
            ValueError: The action is neither a tool call nor a terminal action
                (model validation and handler argument errors also surface as
                ``ValueError`` subclasses).
        """
        manager, episode_lock = self._lookup(episode_id)
        with episode_lock:
            if manager.is_terminated:
                raise RuntimeError("episode already terminated")

            if action.get("action_type") == "submit_diagnosis":
                terminal = TerminalAction.model_validate(action)
                return manager.apply_terminal(terminal)

            if "tool_name" in action:
                tool_call = ToolCall.model_validate(action)
                handler = self._tools.get(tool_call.tool_name)
                if handler is None:
                    raise KeyError(f"unknown tool: {tool_call.tool_name}")
                output = handler.call(tool_call.args, manager.scenario, manager.history)
                return manager.apply_tool_call(tool_call, output)

            raise ValueError(
                "action must be a ToolCall (with tool_name) or a TerminalAction "
                "(with action_type='submit_diagnosis')"
            )

    def state(self, episode_id: str) -> EpisodeState:
        """Return a snapshot of the episode's state.

        Raises:
            KeyError: Unknown episode id.
        """
        manager, episode_lock = self._lookup(episode_id)
        with episode_lock:
            return manager.to_state()
114
+
115
+
class ResetRequest(BaseModel):
    # Body of POST /reset. Both fields are optional.
    # scenario_id: scenario to start; omitted/None lets the environment choose.
    scenario_id: str | None = None
    # seed_override: explicit episode seed; omitted/None lets the environment derive one.
    seed_override: int | None = None
119
+
120
+
class StepRequest(BaseModel):
    # Body of POST /step.
    # episode_id: id returned in the observation from /reset.
    episode_id: str
    # action: raw action payload; the environment decides whether it is a tool
    # call or a terminal action and validates it.
    action: dict
124
+
125
+
def create_app(env: CITriageEnv) -> FastAPI:
    """Build the FastAPI application that exposes ``env`` over HTTP.

    Routes:
        POST /reset              → initial observation (404 for an unknown scenario id)
        POST /step               → next observation (404 unknown id, 400 bad request)
        GET  /state/{episode_id} → episode state (404 for an unknown episode id)
        GET  /mcp/tools          → the MCP tool definitions
    """
    app = FastAPI(title="CI Triage Env")

    @app.post("/reset")
    def reset(req: ResetRequest) -> Observation:
        try:
            return env.reset(scenario_id=req.scenario_id, seed_override=req.seed_override)
        except KeyError as exc:
            raise HTTPException(
                status_code=404, detail=f"unknown scenario_id: {exc.args[0]}"
            ) from exc

    @app.post("/step")
    def step(req: StepRequest) -> Observation:
        try:
            return env.step(req.episode_id, req.action)
        except KeyError as exc:
            # Raised for an unknown episode id and for an unknown tool name.
            raise HTTPException(status_code=404, detail=f"unknown id: {exc.args[0]}") from exc
        except (RuntimeError, ValueError) as exc:
            # RuntimeError: stepping a terminated episode. ValueError: malformed
            # action, failed model validation, or invalid tool arguments.
            # Any other exception is a server-side fault and is deliberately left
            # to propagate (HTTP 500) instead of being reported as a client error.
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.get("/state/{episode_id}")
    def state(episode_id: str) -> EpisodeState:
        try:
            return env.state(episode_id)
        except KeyError as exc:
            raise HTTPException(
                status_code=404, detail=f"unknown episode_id: {exc.args[0]}"
            ) from exc

    @app.get("/mcp/tools")
    def list_mcp_tools() -> list[MCPToolDef]:
        return env.tool_defs

    return app
161
+
162
+
def _bootstrap() -> FastAPI:
    """Build the app from the ``CI_TRIAGE_SCENARIO_SOURCE`` environment variable.

    An unset variable is passed on as ``None``.
    """
    scenario_source = os.environ.get("CI_TRIAGE_SCENARIO_SOURCE")
    return create_app(CITriageEnv(scenario_source=scenario_source))
167
+
168
+
169
+ if __name__ == "__main__":
170
+ import uvicorn
171
+
172
+ uvicorn.run(_bootstrap(), host="0.0.0.0", port=8000)
src/ci_triage_env/env/tools/__init__.py CHANGED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ci_triage_env.env.tools.actions import (
2
+ FileBugHandler,
3
+ PingOwnerHandler,
4
+ QuarantineTestHandler,
5
+ RerunTestHandler,
6
+ )
7
+ from ci_triage_env.env.tools.base import ToolHandler
8
+ from ci_triage_env.env.tools.context import (
9
+ CheckOwnerHandler,
10
+ QueryFlakeHistoryHandler,
11
+ RecentCommitsHandler,
12
+ )
13
+ from ci_triage_env.env.tools.investigation import (
14
+ ClusterMetricsHandler,
15
+ InspectTestCodeHandler,
16
+ ReadLogsHandler,
17
+ RunDiagnosticHandler,
18
+ )
19
+
20
+ ALL_TOOL_HANDLERS: list[ToolHandler] = [
21
+ ReadLogsHandler(),
22
+ InspectTestCodeHandler(),
23
+ RunDiagnosticHandler(),
24
+ ClusterMetricsHandler(),
25
+ QueryFlakeHistoryHandler(),
26
+ RecentCommitsHandler(),
27
+ CheckOwnerHandler(),
28
+ RerunTestHandler(),
29
+ QuarantineTestHandler(),
30
+ FileBugHandler(),
31
+ PingOwnerHandler(),
32
+ ]
33
+
34
+ __all__ = [
35
+ "ALL_TOOL_HANDLERS",
36
+ "CheckOwnerHandler",
37
+ "ClusterMetricsHandler",
38
+ "FileBugHandler",
39
+ "InspectTestCodeHandler",
40
+ "PingOwnerHandler",
41
+ "QuarantineTestHandler",
42
+ "QueryFlakeHistoryHandler",
43
+ "ReadLogsHandler",
44
+ "RecentCommitsHandler",
45
+ "RerunTestHandler",
46
+ "RunDiagnosticHandler",
47
+ "ToolHandler",
48
+ ]
src/ci_triage_env/env/tools/actions.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import ClassVar
2
+
3
+ from ci_triage_env.env.tools.investigation import _StubToolHandler
4
+ from ci_triage_env.schemas.tools import ALL_TOOLS
5
+
6
+ _TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
7
+
8
+
class RerunTestHandler(_StubToolHandler):
    """Stub handler for the ``rerun_test`` tool."""

    name: ClassVar[str] = "rerun_test"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
12
+
13
+
class QuarantineTestHandler(_StubToolHandler):
    """Stub handler for the ``quarantine_test`` tool."""

    name: ClassVar[str] = "quarantine_test"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
17
+
18
+
class FileBugHandler(_StubToolHandler):
    """Stub handler for the ``file_bug`` tool."""

    name: ClassVar[str] = "file_bug"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
22
+
23
+
class PingOwnerHandler(_StubToolHandler):
    """Stub handler for the ``ping_owner`` tool."""

    name: ClassVar[str] = "ping_owner"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
src/ci_triage_env/env/tools/context.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import ClassVar
2
+
3
+ from ci_triage_env.env.tools.investigation import _StubToolHandler
4
+ from ci_triage_env.schemas.tools import ALL_TOOLS
5
+
6
+ _TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
7
+
8
+
class QueryFlakeHistoryHandler(_StubToolHandler):
    """Stub handler for the ``query_flake_history`` tool."""

    name: ClassVar[str] = "query_flake_history"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
12
+
13
+
class RecentCommitsHandler(_StubToolHandler):
    """Stub handler for the ``recent_commits`` tool."""

    name: ClassVar[str] = "recent_commits"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
17
+
18
+
class CheckOwnerHandler(_StubToolHandler):
    """Stub handler for the ``check_owner`` tool."""

    name: ClassVar[str] = "check_owner"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
src/ci_triage_env/env/tools/investigation.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import ClassVar
2
+
3
+ import jsonschema
4
+
5
+ from ci_triage_env.env.tools.base import ToolHandler
6
+ from ci_triage_env.schemas.episode import StepRecord
7
+ from ci_triage_env.schemas.scenario import Scenario, ToolOutput
8
+ from ci_triage_env.schemas.tools import ALL_TOOLS
9
+
10
+ _TOOL_DEFS = {t.name: t for t in ALL_TOOLS}
11
+
12
+
class _StubToolHandler(ToolHandler):
    """Phase A1 stub. Validates args against MCPToolDef.args_schema, returns placeholder payload.

    Subclasses set ``name`` (the key into the tool definitions) and ``cost_unit``.
    """

    name: ClassVar[str] = ""
    cost_unit: ClassVar[float] = 0.0

    def validate_args(self, args: dict) -> None:
        """Check ``args`` against this tool's JSON schema.

        Raises:
            ValueError: If ``args`` does not satisfy the schema; the original
                ``jsonschema.ValidationError`` is chained as the cause.
        """
        tool_def = _TOOL_DEFS[self.name]
        try:
            jsonschema.validate(instance=args, schema=tool_def.args_schema)
        except jsonschema.ValidationError as err:
            raise ValueError(f"invalid args for {self.name}: {err.message}") from err

    def call(
        self,
        args: dict,
        scenario: Scenario,
        history: list[StepRecord],
    ) -> ToolOutput:
        """Validate ``args`` and return a placeholder output.

        ``scenario`` and ``history`` are accepted but not used by the stub.
        """
        self.validate_args(args)
        placeholder = {"stub": True, "tool": self.name}
        return ToolOutput(
            tool_name=self.name,
            payload=placeholder,
            cost_units=self.cost_unit,
        )
38
+
39
+
class ReadLogsHandler(_StubToolHandler):
    """Stub handler for the ``read_logs`` tool; returns a log-shaped placeholder."""

    name: ClassVar[str] = "read_logs"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit

    def call(
        self,
        args: dict,
        scenario: Scenario,
        history: list[StepRecord],
    ) -> ToolOutput:
        """Validate ``args`` and return a single stub log line, not truncated.

        ``scenario`` and ``history`` are accepted but not used by the stub.
        """
        self.validate_args(args)
        log_payload = {"lines": ["[stub]"], "truncated": False}
        return ToolOutput(
            tool_name=self.name,
            payload=log_payload,
            cost_units=self.cost_unit,
        )
56
+
57
+
class InspectTestCodeHandler(_StubToolHandler):
    """Stub handler for the ``inspect_test_code`` tool."""

    name: ClassVar[str] = "inspect_test_code"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
61
+
62
+
class RunDiagnosticHandler(_StubToolHandler):
    """Stub handler for the ``run_diagnostic`` tool."""

    name: ClassVar[str] = "run_diagnostic"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
66
+
67
+
class ClusterMetricsHandler(_StubToolHandler):
    """Stub handler for the ``cluster_metrics`` tool."""

    name: ClassVar[str] = "cluster_metrics"
    # Cost comes from the tool definition registered under the same name.
    cost_unit: ClassVar[float] = _TOOL_DEFS[name].cost_unit
tests/env/__init__.py ADDED
File without changes
tests/env/conftest.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+
4
+ from ci_triage_env.env.server import CITriageEnv, create_app
5
+ from ci_triage_env.mock.scenario import make_mock_scenario
6
+
7
+
8
+ @pytest.fixture
9
+ def env() -> CITriageEnv:
10
+ scenarios = {
11
+ s.scenario_id: s
12
+ for s in [
13
+ make_mock_scenario("race_flake", seed=42),
14
+ make_mock_scenario("real_bug", seed=7),
15
+ ]
16
+ }
17
+ return CITriageEnv(scenarios=scenarios)
18
+
19
+
20
+ @pytest.fixture
21
+ def client(env: CITriageEnv) -> TestClient:
22
+ return TestClient(create_app(env))
23
+
24
+
25
+ @pytest.fixture
26
+ def known_scenario_id(env: CITriageEnv) -> str:
27
+ return next(iter(env.scenarios))
tests/env/test_server.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+
3
+ from fastapi.testclient import TestClient
4
+
5
+ from ci_triage_env.env.server import CITriageEnv, create_app
6
+ from ci_triage_env.mock.scenario import make_mock_scenario
7
+ from ci_triage_env.schemas.action import TerminalAction, ToolCall
8
+ from ci_triage_env.schemas.diagnosis import DiagnosisLabel
9
+ from ci_triage_env.schemas.episode import EpisodeState
10
+ from ci_triage_env.schemas.observation import Observation
11
+
12
+
13
+ def test_server_boots():
14
+ env = CITriageEnv(scenarios={make_mock_scenario().scenario_id: make_mock_scenario()})
15
+ app = create_app(env)
16
+ assert app.title == "CI Triage Env"
17
+
18
+
19
+ def test_reset_returns_valid_observation(client: TestClient):
20
+ resp = client.post("/reset", json={})
21
+ assert resp.status_code == 200
22
+ obs = Observation.model_validate(resp.json())
23
+ assert obs.failure_summary is not None
24
+ assert obs.step == 0
25
+ assert obs.is_terminal is False
26
+
27
+
28
+ def test_reset_with_specific_scenario_id(client: TestClient, known_scenario_id: str):
29
+ resp = client.post("/reset", json={"scenario_id": known_scenario_id})
30
+ assert resp.status_code == 200
31
+ obs = Observation.model_validate(resp.json())
32
+ assert obs.episode_id
33
+
34
+
35
+ def test_reset_with_unknown_scenario_id_404(client: TestClient):
36
+ resp = client.post("/reset", json={"scenario_id": "does-not-exist"})
37
+ assert resp.status_code == 404
38
+
39
+
40
+ def test_step_with_tool_call_returns_observation(client: TestClient, known_scenario_id: str):
41
+ reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
42
+ episode_id = reset["episode_id"]
43
+ call = ToolCall(tool_name="read_logs", args={"scope": "test"})
44
+ resp = client.post("/step", json={"episode_id": episode_id, "action": call.model_dump()})
45
+ assert resp.status_code == 200, resp.text
46
+ obs = Observation.model_validate(resp.json())
47
+ assert obs.tool_response is not None
48
+ assert obs.tool_response.tool_name == "read_logs"
49
+ assert obs.is_terminal is False
50
+
51
+
52
+ def test_step_with_terminal_action_marks_done(client: TestClient, known_scenario_id: str):
53
+ reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
54
+ episode_id = reset["episode_id"]
55
+ terminal = TerminalAction(
56
+ action_type="submit_diagnosis",
57
+ diagnosis=DiagnosisLabel.RACE_FLAKE,
58
+ confidence=0.8,
59
+ )
60
+ resp = client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
61
+ assert resp.status_code == 200, resp.text
62
+ obs = Observation.model_validate(resp.json())
63
+ assert obs.is_terminal is True
64
+
65
+ state_resp = client.get(f"/state/{episode_id}")
66
+ state = EpisodeState.model_validate(state_resp.json())
67
+ assert state.is_terminated is True
68
+ assert state.final_action is not None
69
+
70
+
71
+ def test_step_after_terminal_returns_400(client: TestClient, known_scenario_id: str):
72
+ reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
73
+ episode_id = reset["episode_id"]
74
+ terminal = TerminalAction(
75
+ action_type="submit_diagnosis",
76
+ diagnosis=DiagnosisLabel.RACE_FLAKE,
77
+ confidence=0.8,
78
+ )
79
+ client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
80
+ again = client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
81
+ assert again.status_code == 400
82
+
83
+
84
+ def test_state_endpoint_returns_episode_state(client: TestClient, known_scenario_id: str):
85
+ reset = client.post("/reset", json={"scenario_id": known_scenario_id}).json()
86
+ episode_id = reset["episode_id"]
87
+ resp = client.get(f"/state/{episode_id}")
88
+ assert resp.status_code == 200
89
+ state = EpisodeState.model_validate(resp.json())
90
+ assert state.episode_id == episode_id
91
+ assert state.scenario_id == known_scenario_id
92
+ assert state.step == 0
93
+ assert state.is_terminated is False
94
+
95
+
96
+ def test_state_unknown_episode_404(client: TestClient):
97
+ resp = client.get("/state/not-a-real-episode-id")
98
+ assert resp.status_code == 404
99
+
100
+
101
+ def test_concurrent_resets_get_distinct_episode_ids(client: TestClient, known_scenario_id: str):
102
+ def do_reset() -> str:
103
+ return client.post("/reset", json={"scenario_id": known_scenario_id}).json()["episode_id"]
104
+
105
+ with ThreadPoolExecutor(max_workers=8) as pool:
106
+ ids = list(pool.map(lambda _: do_reset(), range(8)))
107
+
108
+ assert len(set(ids)) == len(ids)
109
+
110
+
111
+ def test_mcp_endpoint_lists_all_11_tools(client: TestClient):
112
+ resp = client.get("/mcp/tools")
113
+ assert resp.status_code == 200
114
+ tools = resp.json()
115
+ names = {t["name"] for t in tools}
116
+ assert names == {
117
+ "read_logs",
118
+ "inspect_test_code",
119
+ "run_diagnostic",
120
+ "cluster_metrics",
121
+ "query_flake_history",
122
+ "recent_commits",
123
+ "check_owner",
124
+ "rerun_test",
125
+ "quarantine_test",
126
+ "file_bug",
127
+ "ping_owner",
128
+ }
129
+ assert len(tools) == 11
130
+
131
+
132
+ def test_episode_seeding_deterministic(client: TestClient, known_scenario_id: str):
133
+ def run_one() -> EpisodeState:
134
+ reset = client.post(
135
+ "/reset",
136
+ json={"scenario_id": known_scenario_id, "seed_override": 12345},
137
+ ).json()
138
+ episode_id = reset["episode_id"]
139
+ call = ToolCall(tool_name="read_logs", args={"scope": "test"})
140
+ client.post("/step", json={"episode_id": episode_id, "action": call.model_dump()})
141
+ terminal = TerminalAction(
142
+ action_type="submit_diagnosis",
143
+ diagnosis=DiagnosisLabel.RACE_FLAKE,
144
+ confidence=0.6,
145
+ )
146
+ client.post("/step", json={"episode_id": episode_id, "action": terminal.model_dump()})
147
+ return EpisodeState.model_validate(client.get(f"/state/{episode_id}").json())
148
+
149
+ a = run_one()
150
+ b = run_one()
151
+ assert a.seed == b.seed == 12345
152
+ assert a.step == b.step
153
+ assert a.is_terminated and b.is_terminated
154
+ assert [r.action for r in a.history] == [r.action for r in b.history]
155
+ assert [r.cost_charged for r in a.history] == [r.cost_charged for r in b.history]
uv.lock CHANGED
@@ -198,6 +198,7 @@ dependencies = [
198
  { name = "httpx" },
199
  { name = "huggingface-hub" },
200
  { name = "jsonschema" },
 
201
  { name = "pydantic" },
202
  { name = "pyyaml" },
203
  { name = "uvicorn", extra = ["standard"] },
@@ -239,6 +240,7 @@ requires-dist = [
239
  { name = "matplotlib", marker = "extra == 'training'", specifier = ">=3.8" },
240
  { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10" },
241
  { name = "openai", marker = "extra == 'data'", specifier = ">=1.40" },
 
242
  { name = "pandas", marker = "extra == 'training'", specifier = ">=2.2" },
243
  { name = "pydantic", specifier = ">=2.7,<3.0" },
244
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8" },
@@ -1304,6 +1306,18 @@ wheels = [
1304
  { url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" },
1305
  ]
1306
 
 
 
 
 
 
 
 
 
 
 
 
 
1307
  [[package]]
1308
  name = "packaging"
1309
  version = "26.2"
 
198
  { name = "httpx" },
199
  { name = "huggingface-hub" },
200
  { name = "jsonschema" },
201
+ { name = "openenv" },
202
  { name = "pydantic" },
203
  { name = "pyyaml" },
204
  { name = "uvicorn", extra = ["standard"] },
 
240
  { name = "matplotlib", marker = "extra == 'training'", specifier = ">=3.8" },
241
  { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10" },
242
  { name = "openai", marker = "extra == 'data'", specifier = ">=1.40" },
243
+ { name = "openenv", specifier = ">=0.1.13" },
244
  { name = "pandas", marker = "extra == 'training'", specifier = ">=2.2" },
245
  { name = "pydantic", specifier = ">=2.7,<3.0" },
246
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8" },
 
1306
  { url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" },
1307
  ]
1308
 
1309
+ [[package]]
1310
+ name = "openenv"
1311
+ version = "0.1.13"
1312
+ source = { registry = "https://pypi.org/simple" }
1313
+ dependencies = [
1314
+ { name = "numpy" },
1315
+ ]
1316
+ sdist = { url = "https://files.pythonhosted.org/packages/35/94/c47e8f7303452793a3519c8cbc1b31dfffdedd13aaed821958ab3f152927/openenv-0.1.13.tar.gz", hash = "sha256:726971d2289472c1c20261436bcccdf3edfcf0b201d16aec127815bd83bfcb3d", size = 5112, upload-time = "2020-12-16T11:49:39.777Z" }
1317
+ wheels = [
1318
+ { url = "https://files.pythonhosted.org/packages/33/7f/e6f4467528161b8f0eb2ec784f4bbcd1fa9ea7acad13c0fb18597013e83b/openenv-0.1.13-py3-none-any.whl", hash = "sha256:813249d7f526f40c6e8b325f705294761a5bc887b9144c3383fa2bae7baa7726", size = 12080, upload-time = "2020-12-16T11:49:38.816Z" },
1319
+ ]
1320
+
1321
  [[package]]
1322
  name = "packaging"
1323
  version = "26.2"