"""Phase A2 tool-handler tests.
Per the phase doc, each of the 11 tools must:
1. return a valid ``ToolOutput`` for valid args (with the correct cost charged)
2. raise ``ValueError`` for invalid args
3. return an empty / no-signal payload when ``scenario.tool_outputs`` lacks the
expected key (rather than crashing)
4. return identical output on a repeated call with identical args
Plus two integration tests: a full 11-tool sweep and cumulative cost charging.
"""
from __future__ import annotations
import json
import pytest
from fastapi.testclient import TestClient
from ci_triage_env.env.server import CITriageEnv, build_app
from ci_triage_env.env.tools import (
CheckOwnerHandler,
ClusterMetricsHandler,
FileBugHandler,
InspectTestCodeHandler,
PingOwnerHandler,
QuarantineTestHandler,
QueryFlakeHistoryHandler,
ReadLogsHandler,
RecentCommitsHandler,
RerunTestHandler,
RunDiagnosticHandler,
)
from ci_triage_env.env.tools.utils import args_hash, deterministic_rng
from ci_triage_env.env.wire import CITriageAction
from ci_triage_env.schemas.scenario import Scenario, ToolOutput
from tests.env.conftest import make_a2_scenario
# (handler_factory, valid_args, invalid_args, missing_key_check, expected_cost_for_valid)
#
# One entry per tool. The pytest param id doubles as the tool's expected
# ``tool_name``. ``missing_key_check`` is the ``scenario.tool_outputs`` key
# stripped by the missing-data test: query-style tools key by argument
# (e.g. "read_logs:test") while action tools use the bare tool name
# (e.g. "rerun_test").
TOOL_MATRIX = [
    pytest.param(
        ReadLogsHandler,
        {"scope": "test", "lines": 100},
        {"scope": "bogus_scope"},
        "read_logs:test",
        0.001 * 100 / 100,  # per-line rate: cost scales linearly with ``lines``
        id="read_logs",
    ),
    pytest.param(
        InspectTestCodeHandler,
        {"test_name": "tests/unit/test_widget.py::test_concurrent_update", "include_fixtures": True},
        {"include_fixtures": True},  # missing required test_name
        "inspect_test_code:tests/unit/test_widget.py::test_concurrent_update",
        0.05,
        id="inspect_test_code",
    ),
    pytest.param(
        RunDiagnosticHandler,
        {"probe": "network"},
        {"probe": "not_a_probe"},
        "run_diagnostic:network",
        0.10,
        id="run_diagnostic",
    ),
    pytest.param(
        ClusterMetricsHandler,
        {"metric": "queue_depth", "window_minutes": 15},
        {"metric": "no_such_metric"},
        "cluster_metrics:queue_depth",
        0.02,
        id="cluster_metrics",
    ),
    pytest.param(
        QueryFlakeHistoryHandler,
        {"test_name": "tests/unit/test_widget.py::test_concurrent_update"},
        {},  # missing test_name
        "query_flake_history:tests/unit/test_widget.py::test_concurrent_update",
        0.01,
        id="query_flake_history",
    ),
    pytest.param(
        RecentCommitsHandler,
        {"branch": "main", "limit": 5},
        {"limit": 5},  # missing branch
        "recent_commits:main",
        0.01,
        id="recent_commits",
    ),
    pytest.param(
        CheckOwnerHandler,
        {"target": "tests/unit/test_widget.py"},
        {},  # missing target
        "check_owner:tests/unit/test_widget.py",
        0.01,
        id="check_owner",
    ),
    pytest.param(
        RerunTestHandler,
        {"test_name": "tests/unit/test_widget.py::test_concurrent_update", "iterations": 2},
        {"iterations": 2},  # missing test_name
        "rerun_test",
        0.30,
        id="rerun_test",
    ),
    pytest.param(
        QuarantineTestHandler,
        {"test_name": "tests/unit/test_widget.py::test_concurrent_update", "reason": "flake"},
        {"test_name": "x"},  # missing reason
        "quarantine_test",
        0.0,  # quarantining is free
        id="quarantine_test",
    ),
    pytest.param(
        FileBugHandler,
        {"title": "t", "summary": "s", "owner": "alice", "severity": "high"},
        {"title": "t", "summary": "s", "owner": "alice"},  # missing severity
        "file_bug",
        0.5,
        id="file_bug",
    ),
    pytest.param(
        PingOwnerHandler,
        {"owner": "alice", "message": "hey"},
        {"owner": "alice"},  # missing message
        "ping_owner",
        0.083,
        id="ping_owner",
    ),
]
@pytest.fixture
def scenario() -> Scenario:
    """Build a fresh Phase A2 scenario for each test."""
    built = make_a2_scenario()
    return built
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_valid_args_returns_output(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
    """Valid arguments produce a ``ToolOutput`` carrying the matrix's expected cost."""
    h = handler_cls()
    result = h.call(valid_args, scenario, [])
    assert isinstance(result, ToolOutput)
    assert result.tool_name == h.name
    assert result.cost_units == pytest.approx(expected_cost)
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_invalid_args_raises(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
    """Malformed or incomplete arguments are rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        handler_cls().call(invalid_args, scenario, [])
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_missing_scenario_data_returns_empty(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
    """With the tool's data key absent from ``scenario.tool_outputs``, the
    handler must still return a well-formed (empty-payload) output rather
    than crash."""
    remaining = dict(scenario.tool_outputs)
    remaining.pop(key, None)
    bare = scenario.model_copy(update={"tool_outputs": remaining})
    h = handler_cls()
    result = h.call(valid_args, bare, [])
    assert isinstance(result, ToolOutput)
    assert result.tool_name == h.name
    # A (non-negative) cost is charged even when the scenario carries no data.
    assert result.cost_units >= 0.0
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_repeated_call_returns_same_output(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
    """Handlers are deterministic: identical args yield identical outputs."""
    h = handler_cls()
    outputs = [h.call(valid_args, scenario, []) for _ in range(2)]
    assert outputs[0] == outputs[1]
# ---------------------------------------------------------------------------
# Read-logs cost scaling deserves a focused test (Phase A2 §implementation note)
# ---------------------------------------------------------------------------
def test_read_logs_cost_scales_with_lines(scenario):
    """Doubling ``lines`` doubles the charged cost (linear pricing)."""
    handler = ReadLogsHandler()
    at_100 = handler.call({"scope": "test", "lines": 100}, scenario, [])
    at_200 = handler.call({"scope": "test", "lines": 200}, scenario, [])
    assert at_200.cost_units == pytest.approx(at_100.cost_units * 2.0)
def test_read_logs_truncates_when_lines_smaller_than_payload(scenario):
    """Requesting fewer lines than the log holds truncates and flags it."""
    result = ReadLogsHandler().call({"scope": "test", "lines": 10}, scenario, [])
    payload = result.payload
    assert isinstance(payload, dict)
    assert payload["truncated"] is True
    assert len(payload["lines"]) == 10
# ---------------------------------------------------------------------------
# Integration: full 11-tool sweep + cumulative cost
# ---------------------------------------------------------------------------
def test_full_tool_loop_against_mock_scenario(scenario):
    """Every tool in the matrix runs cleanly against a single scenario and
    reports its own name (the param id) in the output."""
    seen: set[str] = set()
    for param in TOOL_MATRIX:
        handler_cls, valid_args = param.values[0], param.values[1]
        result = handler_cls().call(valid_args, scenario, [])
        assert isinstance(result, ToolOutput)
        seen.add(result.tool_name)
    assert seen == {param.id for param in TOOL_MATRIX}
def test_cost_charging_accumulates_correctly(a2_scenario):
    """Drive the env over WS with a sequence of tool calls; verify the budget
    deducts exactly the sum of each handler's reported ``cost_units``.

    Builds its own app/TestClient (instead of the shared ``client`` fixture,
    which this test never used) so the richer ``a2_scenario`` can be injected
    via ``build_app``.
    """
    def factory() -> CITriageEnv:
        # Fresh env pre-loaded with only the scenario under test.
        return CITriageEnv(scenarios={a2_scenario.scenario_id: a2_scenario})

    app = build_app(env_factory=factory)
    c = TestClient(app)
    sequence = [
        CITriageAction.from_tool_call("read_logs", {"scope": "test", "lines": 100}),
        CITriageAction.from_tool_call("query_flake_history", {"test_name": a2_scenario.failure_summary.test_name}),
        CITriageAction.from_tool_call("recent_commits", {"branch": a2_scenario.failure_summary.branch, "limit": 5}),
        CITriageAction.from_tool_call("rerun_test", {"test_name": a2_scenario.failure_summary.test_name, "iterations": 2}),
        CITriageAction.from_tool_call("ping_owner", {"owner": "alice", "message": "hi"}),
    ]
    # Per-call costs matching TOOL_MATRIX (read_logs at lines=100 -> base rate).
    expected_costs = [0.001, 0.01, 0.01, 0.30, 0.083]
    with c.websocket_connect("/ws") as ws:
        ws.send_text(json.dumps({"type": "reset", "data": {"scenario_id": a2_scenario.scenario_id}}))
        initial = json.loads(ws.receive_text())
        budget0 = initial["data"]["observation"]["payload"]["budget_remaining"]["cost_remaining"]
        running = budget0
        for action, expected in zip(sequence, expected_costs, strict=True):
            ws.send_text(json.dumps({"type": "step", "data": action.model_dump()}))
            resp = json.loads(ws.receive_text())
            payload = resp["data"]["observation"]["payload"]
            charged = payload["tool_response"]["cost_charged"]
            assert charged == pytest.approx(expected)
            running -= charged
            # Remaining budget must track the running deduction exactly.
            assert payload["budget_remaining"]["cost_remaining"] == pytest.approx(running)
# ---------------------------------------------------------------------------
# utils smoke
# ---------------------------------------------------------------------------
def test_args_hash_is_stable_and_order_independent():
    """Hashing ignores dict key order but is sensitive to values."""
    forward = args_hash({"x": 1, "y": [1, 2]})
    reordered = args_hash({"y": [1, 2], "x": 1})
    assert forward == reordered
    assert args_hash({"x": 2, "y": [1, 2]}) != forward
def test_deterministic_rng_is_reproducible():
    """Identical (seed, step, name) triples reproduce the same stream; a
    different step diverges."""
    first = deterministic_rng(42, 3, "read_logs").random()
    repeat = deterministic_rng(42, 3, "read_logs").random()
    other_step = deterministic_rng(42, 4, "read_logs").random()
    assert first == repeat
    assert first != other_step