Spaces:
Running
Running
| """ | |
| End-to-end API suite for the full endpoint contract. | |
| This suite focuses on: | |
| 1) endpoint availability | |
| 2) cross-endpoint data flow | |
| 3) session lifecycle correctness | |
| 4) simulation stream behavior | |
| 5) RL endpoint guardrails | |
| """ | |
| from __future__ import annotations | |
| from httpx import ASGITransport, AsyncClient | |
| from app.main import app | |
| from rl.feature_builder import N_ACTIONS | |
# Base URL for the in-process httpx client. Requests are routed through
# ASGITransport directly into `app`, so this host is never actually resolved.
BASE_URL = "http://test"

# Every route the public API contract promises, grouped by feature area.
# Templates must match the keys of the generated OpenAPI `paths` verbatim.
REQUIRED_PATHS = {
    # Service introspection
    "/health",
    "/metrics",
    "/actions/schema",
    # Task catalogue
    "/tasks",
    "/tasks/{task_id}",
    # Session lifecycle
    "/reset",
    "/step",
    "/state",
    "/grade",
    "/action-masks",
    # Simulation runs
    "/simulate",
    "/simulate/{session_id}/snapshot",
    "/simulate/{session_id}/cancel",
    "/simulate/{session_id}/trace",
    # Reinforcement-learning endpoints
    "/rl/run",
    "/rl/models",
}
async def test_openapi_contains_all_required_endpoints() -> None:
    """Every route in REQUIRED_PATHS must appear in the app's OpenAPI schema."""
    published = set(app.openapi().get("paths", {}))
    missing = REQUIRED_PATHS - published
    assert not missing, f"Missing paths: {sorted(missing)}"
async def test_health_tasks_metrics_and_schema_consistency() -> None:
    """Cross-check /health, /tasks, /metrics and /actions/schema against each other.

    /health and /metrics must report the same version and phase labels, and the
    task IDs listed by /tasks must be exactly the ones /metrics advertises.
    """
    expected_version = "2.0.0"
    expected_phase = "3_rl_training"
    expected_task_ids = {
        "district_backlog_easy",
        "mixed_urgency_medium",
        "cross_department_hard",
    }

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url=BASE_URL) as client:
        # Requests are awaited one after another, in this exact order.
        health, tasks, metrics, schema = [
            await client.get(path)
            for path in ("/health", "/tasks", "/metrics", "/actions/schema")
        ]

    assert health.status_code == 200
    health_body = health.json()
    assert health_body["status"] in {"ok", "degraded"}
    assert health_body["version"] == expected_version
    assert health_body["phase"] == expected_phase

    assert tasks.status_code == 200
    task_rows = tasks.json()
    assert isinstance(task_rows, list)
    assert len(task_rows) == 3
    listed_ids = {row["task_id"] for row in task_rows}
    assert listed_ids == expected_task_ids

    assert metrics.status_code == 200
    metrics_body = metrics.json()
    assert metrics_body["version"] == expected_version
    assert metrics_body["phase"] == expected_phase
    assert metrics_body["total_tasks"] == 3
    # /metrics must agree with what /tasks actually listed.
    assert set(metrics_body["tasks_available"]) == listed_ids

    assert schema.status_code == 200
    schema_body = schema.json()
    assert schema_body["total_action_types"] == 6
    assert len(schema_body["actions"]) == 6
async def test_per_task_details_and_unknown_task_404() -> None:
    """Each known task returns a well-formed detail record; an unknown ID yields 404."""
    known_task_ids = (
        "district_backlog_easy",
        "mixed_urgency_medium",
        "cross_department_hard",
    )
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url=BASE_URL) as client:
        for expected_id in known_task_ids:
            response = await client.get(f"/tasks/{expected_id}")
            assert response.status_code == 200
            detail = response.json()
            assert detail["task_id"] == expected_id
            assert detail["max_days"] > 0
            assert detail["officer_pool_total"] > 0
            services = detail["services"]
            assert isinstance(services, list)
            assert services  # at least one service per task

        unknown = await client.get("/tasks/fake_task")
        assert unknown.status_code == 404
async def test_session_data_flow_reset_masks_step_trace_snapshot_grade_cancel() -> None:
    """Drive one session through every session-scoped endpoint, in order.

    Flow: ``/reset`` -> ``/action-masks`` -> 3x ``/step`` (``advance_time``) ->
    ``/state`` -> ``/simulate/{sid}/trace`` (pages 1 and 2) ->
    ``/simulate/{sid}/snapshot`` -> ``/grade`` -> ``/simulate/{sid}/cancel``.
    After the cancel, ``/state`` must answer 404 for the same session.

    If an assertion fails before the cancel step, the ``finally`` clause still
    cancels the session. That keeps it from being left open on the module-level
    ``app``, which every other test in this file also talks to.
    """
    async with AsyncClient(transport=ASGITransport(app=app), base_url=BASE_URL) as c:
        reset = await c.post("/reset", json={"task_id": "district_backlog_easy", "seed": 42})
        assert reset.status_code == 200
        reset_body = reset.json()
        sid = reset_body["session_id"]
        cancelled = False
        try:
            # 36 chars is the length of a canonical UUID string; presumably the
            # server issues UUIDs. Confirm before loosening this check.
            assert len(sid) == 36
            assert reset_body["task_id"] == "district_backlog_easy"
            assert reset_body["observation"]["day"] == 0

            masks = await c.post("/action-masks", json={"session_id": sid})
            assert masks.status_code == 200
            mask_body = masks.json()
            assert len(mask_body["action_mask"]) == N_ACTIONS
            assert mask_body["total_actions"] == N_ACTIONS
            assert mask_body["total_valid"] > 0

            for _ in range(3):
                step = await c.post(
                    "/step",
                    json={"session_id": sid, "action": {"action_type": "advance_time"}},
                )
                assert step.status_code == 200

            state = await c.get(
                "/state", params={"session_id": sid, "include_action_history": True}
            )
            assert state.status_code == 200
            st = state.json()["state"]
            assert st["day"] >= 1
            # Three /step calls were made above, so expect at least three entries.
            assert st["action_history_count"] >= 3

            # With page_size=2 and >= 3 recorded steps, page 1 is full and
            # page 2 holds at least one step.
            trace_url = f"/simulate/{sid}/trace"
            trace_page1 = await c.get(trace_url, params={"page": 1, "page_size": 2})
            trace_page2 = await c.get(trace_url, params={"page": 2, "page_size": 2})
            assert trace_page1.status_code == 200
            assert trace_page2.status_code == 200
            p1 = trace_page1.json()
            p2 = trace_page2.json()
            assert p1["total_steps"] >= 3
            assert len(p1["steps"]) == 2
            assert p2["page"] == 2
            assert len(p2["steps"]) >= 1

            snap = await c.get(f"/simulate/{sid}/snapshot")
            assert snap.status_code == 200
            snap_body = snap.json()
            assert snap_body["session_id"] == sid
            assert "observation" in snap_body

            grade = await c.post("/grade", json={"session_id": sid})
            assert grade.status_code == 200
            g = grade.json()
            assert g["task_id"] == "district_backlog_easy"
            assert 0.0 <= g["score"] <= 1.0
            assert isinstance(g["metrics"], dict)

            cancel = await c.post(f"/simulate/{sid}/cancel")
            assert cancel.status_code == 200
            assert cancel.json()["status"] == "cancelled"
            cancelled = True

            # After cancel, the session must no longer be addressable via /state.
            state_after = await c.get("/state", params={"session_id": sid})
            assert state_after.status_code == 404
        finally:
            if not cancelled:
                # Best-effort cleanup. The response is deliberately ignored so it
                # cannot replace the assertion error that brought us here.
                await c.post(f"/simulate/{sid}/cancel")
async def test_simulate_endpoint_validation_contract() -> None:
    """/simulate must reject an unknown task_id and an unknown agent_mode with 422."""
    # Each bad request overrides exactly one field of this base payload.
    base_payload = {
        "task_id": "district_backlog_easy",
        "agent_mode": "baseline_policy",
        "max_steps": 3,
        "seed": 123,
    }
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url=BASE_URL, timeout=30.0) as client:
        bad_task = await client.post(
            "/simulate", json={**base_payload, "task_id": "not_a_real_task"}
        )
        bad_mode = await client.post(
            "/simulate", json={**base_payload, "agent_mode": "wrong_mode"}
        )

    assert bad_task.status_code == 422
    assert bad_mode.status_code == 422
async def test_rl_models_and_rl_run_missing_model_guardrail() -> None:
    """/rl/models lists model entries; /rl/run rejects a nonexistent model path with 422."""
    run_request = {
        "task_id": "district_backlog_easy",
        "model_path": "results/best_model/does_not_exist",
        "seed": 42,
        "max_steps": 10,
        "n_episodes": 1,
    }
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url=BASE_URL) as client:
        listing = await client.get("/rl/models")
        assert listing.status_code == 200
        entries = listing.json()
        assert isinstance(entries, list)
        assert entries  # at least one model entry
        for entry in entries:
            assert all(key in entry for key in ("model_path", "exists"))

        missing = await client.post("/rl/run", json=run_request)
        assert missing.status_code == 422