Spaces:
Running
Running
File size: 2,172 Bytes
df97e68 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | from fastapi.testclient import TestClient
from app.main import app
def test_simulation_history_persists_completed_runs() -> None:
client = TestClient(app)
run_resp = client.post(
"/api/simulation/run",
json={
"task_id": "district_backlog_easy",
"agent_mode": "baseline_policy",
"policy_name": "backlog_clearance",
"max_steps": 5,
"seed": 123,
},
)
assert run_resp.status_code == 200
history_resp = client.get("/api/history/simulations")
assert history_resp.status_code == 200
runs = history_resp.json().get("runs", [])
assert isinstance(runs, list)
assert any(row.get("task_id") == "district_backlog_easy" for row in runs)
run_id = next((row.get("run_id") for row in runs if row.get("run_id")), None)
assert run_id
detail_resp = client.get(f"/api/history/simulations/{run_id}")
assert detail_resp.status_code == 200
detail = detail_resp.json()
assert detail.get("run_id") == run_id
def test_comparison_history_roundtrip() -> None:
client = TestClient(app)
payload = {
"task_id": "district_backlog_easy",
"baseline_policy": "backlog_clearance",
"model_path": "results/best_model/phase2_final.zip",
"model_type": "maskable",
"include_llm": True,
"runs": 2,
"steps": 10,
"episodes": 1,
"seed_base": 100,
"result": {
"baselineScore": 0.6,
"trainedScore": 0.7,
"llmScore": 0.5,
},
}
create_resp = client.post("/api/history/comparisons", json=payload)
assert create_resp.status_code == 200
comparison_id = create_resp.json().get("comparison_id")
assert comparison_id
list_resp = client.get("/api/history/comparisons")
assert list_resp.status_code == 200
rows = list_resp.json().get("comparisons", [])
assert any(row.get("comparison_id") == comparison_id for row in rows)
detail_resp = client.get(f"/api/history/comparisons/{comparison_id}")
assert detail_resp.status_code == 200
assert detail_resp.json().get("comparison_id") == comparison_id
|