Spaces:
Running
Running
| """ | |
| tests/test_story_router.py | |
| Tests for all 7 /training/* endpoints. | |
| Requires: data/training_logs/mixed_urgency_medium_training_log.json | |
| """ | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from app.main import app | |
# Task id whose training log (listed under "Requires" in the module docstring)
# backs every /training/* request made below.
TASK = "mixed_urgency_medium"
# One in-process client shared by every test in this module; FastAPI's
# TestClient calls the app directly, so no separate server has to be running.
client = TestClient(app)
def test_list_tasks():
    """GET /training/tasks answers 200 with a JSON list under the "tasks" key."""
    response = client.get("/training/tasks")
    assert response.status_code == 200
    payload = response.json()
    assert "tasks" in payload
    task_list = payload["tasks"]
    assert isinstance(task_list, list)
def test_summary():
    """GET /training/summary/{task} echoes the task id and returns a phased narrative."""
    response = client.get(f"/training/summary/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert payload["task_id"] == TASK
    for top_level_key in ("summary", "narrative"):
        assert top_level_key in payload
    narrative = payload["narrative"]
    # Only the first and last phases are spot-checked.
    for phase_key in ("phase_1", "phase_4"):
        assert phase_key in narrative
def test_curve_full():
    """GET /training/curve/{task} returns a non-empty curve whose points carry the expected fields."""
    response = client.get(f"/training/curve/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert "curve" in payload
    points = payload["curve"]
    assert len(points) > 0
    first_point = points[0]
    # The first point stands in for the shape of every point.
    for field in ("episode", "reward", "score", "phase"):
        assert field in first_point
def test_curve_downsample():
    """GET /training/curve/{task}?downsample=5 returns a reduced, non-empty curve.

    The previous check (``total_points <= 100000``) was a magic bound that never
    exercised downsampling. This version compares against the un-downsampled curve
    from the same endpoint. The downsampled curve must still contain points and must
    never be longer than the full curve.

    NOTE(review): the exact meaning of ``downsample=5`` (every 5th point vs. a target
    count) is not visible here, so only the monotonic "no more points" property is
    asserted. Tighten this once the contract is confirmed.
    """
    r = client.get(f"/training/curve/{TASK}?downsample=5")
    assert r.status_code == 200
    data = r.json()
    assert "total_points" in data
    assert "curve" in data

    full = client.get(f"/training/curve/{TASK}")
    assert full.status_code == 200
    full_curve = full.json()["curve"]

    # Downsampling may drop points but must never add any.
    assert 0 < len(data["curve"]) <= len(full_curve)
def test_actions():
    """GET /training/actions/{task} returns exactly five checkpoints plus an insight."""
    response = client.get(f"/training/actions/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert "checkpoints" in payload
    checkpoint_count = len(payload["checkpoints"])
    assert checkpoint_count == 5
    assert "insight" in payload
def test_episode_first():
    """Episode 1 is retrievable, echoes its number, and exposes every per-episode metric."""
    response = client.get(f"/training/episode/{TASK}/1")
    assert response.status_code == 200
    episode = response.json()
    assert episode["episode"] == 1
    required_fields = (
        "reward",
        "score",
        "fn1_valid",
        "fn2_no_halluc",
        "fn3_env_score",
        "message",
        "running_best_reward",
    )
    for field in required_fields:
        assert field in episode
def test_episode_last():
    """The last episode (number == ``total_episodes`` from the summary) is retrievable.

    Episodes are addressed starting at 1 (see the first-episode test), so the
    summary's ``total_episodes`` is the highest valid episode number.
    """
    summary_response = client.get(f"/training/summary/{TASK}")
    # Fail on the summary call itself rather than with a confusing KeyError below.
    assert summary_response.status_code == 200
    total = summary_response.json()["total_episodes"]

    r = client.get(f"/training/episode/{TASK}/{total}")
    assert r.status_code == 200
    # The endpoint echoes the requested episode number (as checked for episode 1);
    # make sure the last one is really the one returned.
    assert r.json()["episode"] == total
def test_episode_out_of_range():
    """Requesting the episode just past the last one is rejected with HTTP 400.

    The test derives the boundary from the summary's ``total_episodes`` instead of a
    hard-coded large number. This exercises the exact off-by-one edge, and the test
    stays correct no matter how long the training log grows.
    """
    summary_response = client.get(f"/training/summary/{TASK}")
    assert summary_response.status_code == 200
    total = summary_response.json()["total_episodes"]

    # total is the last valid episode, so total + 1 is the first invalid one.
    r = client.get(f"/training/episode/{TASK}/{total + 1}")
    assert r.status_code == 400
def test_comparison():
    """GET /training/comparison/{task} reports positive before/after scores and a verdict."""
    response = client.get(f"/training/comparison/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    for section in ("before", "after", "improvement"):
        assert section in payload
    assert "verdict" in payload["improvement"]
    # Both snapshots must carry a strictly positive score.
    for snapshot in ("before", "after"):
        assert payload[snapshot]["score"] > 0
def test_missing_task_404():
    """An unknown task id on the summary endpoint yields HTTP 404."""
    unknown_task = "nonexistent_task_xyz"
    response = client.get(f"/training/summary/{unknown_task}")
    assert response.status_code == 404
def test_stream_headers():
    """The SSE stream endpoint answers 200 with a text/event-stream content type.

    Only the status and headers are inspected; the event body is not consumed.
    """
    stream_url = f"/training/stream/{TASK}?delay_ms=0"
    with client.stream("GET", stream_url) as response:
        assert response.status_code == 200
        content_type = response.headers["content-type"]
        assert "text/event-stream" in content_type