# ci-triage-env/tests/env/test_tools.py
"""Phase A2 tool-handler tests.
Per the phase doc, each of the 11 tools must:
1. return a valid ``ToolOutput`` for valid args (with the correct cost charged)
2. raise ``ValueError`` for invalid args
3. return an empty / no-signal payload when ``scenario.tool_outputs`` lacks the
expected key (rather than crashing)
4. return identical output on a repeated call with identical args
Plus two integration tests: a full 11-tool sweep and cumulative cost charging.
"""
from __future__ import annotations
import json
import pytest
from fastapi.testclient import TestClient
from ci_triage_env.env.server import CITriageEnv, build_app
from ci_triage_env.env.tools import (
CheckOwnerHandler,
ClusterMetricsHandler,
FileBugHandler,
InspectTestCodeHandler,
PingOwnerHandler,
QuarantineTestHandler,
QueryFlakeHistoryHandler,
ReadLogsHandler,
RecentCommitsHandler,
RerunTestHandler,
RunDiagnosticHandler,
)
from ci_triage_env.env.tools.utils import args_hash, deterministic_rng
from ci_triage_env.env.wire import CITriageAction
from ci_triage_env.schemas.scenario import Scenario, ToolOutput
from tests.env.conftest import make_a2_scenario
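# The tests below exercise every handler through one shared contract. As
# inferred from the call sites in this file (a sketch of the interface, not the
# base-class definition itself): each handler exposes ``name`` and
# ``call(args: dict, scenario: Scenario, history: list) -> ToolOutput``, and the
# returned ``ToolOutput`` carries ``tool_name``, ``cost_units`` and a dict
# ``payload``.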
# Each row: (handler_cls, valid_args, invalid_args, scenario.tool_outputs key, expected cost for valid args)
TOOL_MATRIX = [
pytest.param(
ReadLogsHandler,
{"scope": "test", "lines": 100},
{"scope": "bogus_scope"},
"read_logs:test",
        0.001 * 100 / 100,  # rate of 0.001 cost units per 100 requested lines
id="read_logs",
),
pytest.param(
InspectTestCodeHandler,
{"test_name": "tests/unit/test_widget.py::test_concurrent_update", "include_fixtures": True},
{"include_fixtures": True}, # missing required test_name
"inspect_test_code:tests/unit/test_widget.py::test_concurrent_update",
0.05,
id="inspect_test_code",
),
pytest.param(
RunDiagnosticHandler,
{"probe": "network"},
{"probe": "not_a_probe"},
"run_diagnostic:network",
0.10,
id="run_diagnostic",
),
pytest.param(
ClusterMetricsHandler,
{"metric": "queue_depth", "window_minutes": 15},
{"metric": "no_such_metric"},
"cluster_metrics:queue_depth",
0.02,
id="cluster_metrics",
),
pytest.param(
QueryFlakeHistoryHandler,
{"test_name": "tests/unit/test_widget.py::test_concurrent_update"},
{}, # missing test_name
"query_flake_history:tests/unit/test_widget.py::test_concurrent_update",
0.01,
id="query_flake_history",
),
pytest.param(
RecentCommitsHandler,
{"branch": "main", "limit": 5},
{"limit": 5}, # missing branch
"recent_commits:main",
0.01,
id="recent_commits",
),
pytest.param(
CheckOwnerHandler,
{"target": "tests/unit/test_widget.py"},
{}, # missing target
"check_owner:tests/unit/test_widget.py",
0.01,
id="check_owner",
),
pytest.param(
RerunTestHandler,
{"test_name": "tests/unit/test_widget.py::test_concurrent_update", "iterations": 2},
{"iterations": 2}, # missing test_name
"rerun_test",
0.30,
id="rerun_test",
),
pytest.param(
QuarantineTestHandler,
{"test_name": "tests/unit/test_widget.py::test_concurrent_update", "reason": "flake"},
{"test_name": "x"}, # missing reason
"quarantine_test",
0.0,
id="quarantine_test",
),
pytest.param(
FileBugHandler,
{"title": "t", "summary": "s", "owner": "alice", "severity": "high"},
{"title": "t", "summary": "s", "owner": "alice"}, # missing severity
"file_bug",
0.5,
id="file_bug",
),
pytest.param(
PingOwnerHandler,
{"owner": "alice", "message": "hey"},
{"owner": "alice"}, # missing message
"ping_owner",
0.083,
id="ping_owner",
),
]
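# Not required by the phase doc, but a cheap structural guard that the matrix
# above stays in sync with the "11 tools" claim in the module docstring.
def test_tool_matrix_covers_eleven_distinct_tools():
    ids = [p.id for p in TOOL_MATRIX]
    handler_classes = [p.values[0] for p in TOOL_MATRIX]
    assert len(ids) == len(set(ids)) == 11
    assert len(handler_classes) == len(set(handler_classes)) == 11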
@pytest.fixture
def scenario() -> Scenario:
return make_a2_scenario()
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_valid_args_returns_output(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
handler = handler_cls()
out = handler.call(valid_args, scenario, [])
assert isinstance(out, ToolOutput)
assert out.tool_name == handler.name
assert out.cost_units == pytest.approx(expected_cost)
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_invalid_args_raises(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
handler = handler_cls()
with pytest.raises(ValueError):
handler.call(invalid_args, scenario, [])
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_missing_scenario_data_returns_empty(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
"""Strip the relevant key from scenario.tool_outputs and confirm the
handler returns a non-crashing empty payload."""
stripped = {k: v for k, v in scenario.tool_outputs.items() if k != key}
bare = scenario.model_copy(update={"tool_outputs": stripped})
handler = handler_cls()
out = handler.call(valid_args, bare, [])
assert isinstance(out, ToolOutput)
assert out.tool_name == handler.name
# Cost is still charged even when the scenario doesn't carry data.
assert out.cost_units >= 0.0
@pytest.mark.parametrize("handler_cls,valid_args,invalid_args,key,expected_cost", TOOL_MATRIX)
def test_tool_repeated_call_returns_same_output(handler_cls, valid_args, invalid_args, key, expected_cost, scenario):
handler = handler_cls()
first = handler.call(valid_args, scenario, [])
second = handler.call(valid_args, scenario, [])
assert first == second
# ---------------------------------------------------------------------------
# Read-logs cost scaling deserves a focused test (Phase A2 §implementation note)
# ---------------------------------------------------------------------------
def test_read_logs_cost_scales_with_lines(scenario):
h = ReadLogsHandler()
cheap = h.call({"scope": "test", "lines": 100}, scenario, [])
pricey = h.call({"scope": "test", "lines": 200}, scenario, [])
assert pricey.cost_units == pytest.approx(2.0 * cheap.cost_units)
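# Pin the absolute base price as well as the scaling: the matrix entry above
# encodes 0.001 cost units per 100 requested lines (a rate inferred from
# TOOL_MATRIX and the WS integration test below, not quoted from the phase doc).
def test_read_logs_base_cost_for_100_lines(scenario):
    out = ReadLogsHandler().call({"scope": "test", "lines": 100}, scenario, [])
    assert out.cost_units == pytest.approx(0.001)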
def test_read_logs_truncates_when_lines_smaller_than_payload(scenario):
out = ReadLogsHandler().call({"scope": "test", "lines": 10}, scenario, [])
assert isinstance(out.payload, dict)
assert len(out.payload["lines"]) == 10
assert out.payload["truncated"] is True
# ---------------------------------------------------------------------------
# Integration: full 11-tool sweep + cumulative cost
# ---------------------------------------------------------------------------
def test_full_tool_loop_against_mock_scenario(scenario):
expected = {p.id for p in TOOL_MATRIX}
seen: set[str] = set()
for handler_cls, valid_args, _inv, _key, _cost in [p.values for p in TOOL_MATRIX]:
out = handler_cls().call(valid_args, scenario, [])
assert isinstance(out, ToolOutput)
seen.add(out.tool_name)
assert seen == expected
def test_cost_charging_accumulates_correctly(a2_scenario):
    """Drive the env over WS with a sequence of tool calls; verify the budget
    deducts exactly the sum of each handler's reported ``cost_units``."""
    # Inject our richer scenario into a fresh CITriageEnv via build_app.
    def factory() -> CITriageEnv:
        return CITriageEnv(scenarios={a2_scenario.scenario_id: a2_scenario})
    app = build_app(env_factory=factory)
    c = TestClient(app)
sequence = [
CITriageAction.from_tool_call("read_logs", {"scope": "test", "lines": 100}),
CITriageAction.from_tool_call("query_flake_history", {"test_name": a2_scenario.failure_summary.test_name}),
CITriageAction.from_tool_call("recent_commits", {"branch": a2_scenario.failure_summary.branch, "limit": 5}),
CITriageAction.from_tool_call("rerun_test", {"test_name": a2_scenario.failure_summary.test_name, "iterations": 2}),
CITriageAction.from_tool_call("ping_owner", {"owner": "alice", "message": "hi"}),
]
expected_costs = [0.001, 0.01, 0.01, 0.30, 0.083]
with c.websocket_connect("/ws") as ws:
ws.send_text(json.dumps({"type": "reset", "data": {"scenario_id": a2_scenario.scenario_id}}))
initial = json.loads(ws.receive_text())
budget0 = initial["data"]["observation"]["payload"]["budget_remaining"]["cost_remaining"]
running = budget0
for action, expected in zip(sequence, expected_costs, strict=True):
ws.send_text(json.dumps({"type": "step", "data": action.model_dump()}))
resp = json.loads(ws.receive_text())
charged = resp["data"]["observation"]["payload"]["tool_response"]["cost_charged"]
assert charged == pytest.approx(expected)
running -= charged
assert resp["data"]["observation"]["payload"]["budget_remaining"]["cost_remaining"] == pytest.approx(running)
# ---------------------------------------------------------------------------
# utils smoke
# ---------------------------------------------------------------------------
def test_args_hash_is_stable_and_order_independent():
a = args_hash({"x": 1, "y": [1, 2]})
b = args_hash({"y": [1, 2], "x": 1})
assert a == b
assert a != args_hash({"x": 2, "y": [1, 2]})
def test_deterministic_rng_is_reproducible():
r1 = deterministic_rng(42, 3, "read_logs").random()
r2 = deterministic_rng(42, 3, "read_logs").random()
assert r1 == r2
assert r1 != deterministic_rng(42, 4, "read_logs").random()