File size: 2,967 Bytes
df97e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
tests/test_story_router.py
Tests for all 7 /training/* endpoints.
Requires: data/training_logs/mixed_urgency_medium_training_log.json
"""

import pytest
from fastapi.testclient import TestClient
from app.main import app

# Shared test client wrapping the application imported from app.main; the
# tests in this module issue their requests through it.
# NOTE(review): created without a `with` block, so startup/lifespan handlers
# presumably do not run -- confirm the /training routes don't rely on them.
client = TestClient(app)
# Task id interpolated into the task-specific endpoint paths; the module
# docstring names the training-log file that must exist for it.
TASK = "mixed_urgency_medium"


def test_list_tasks():
    """GET /training/tasks answers 200 with a list under the "tasks" key."""
    response = client.get("/training/tasks")
    assert response.status_code == 200
    payload = response.json()
    assert "tasks" in payload
    task_listing = payload["tasks"]
    assert isinstance(task_listing, list)


def test_summary():
    """GET /training/summary/{TASK} echoes the task id and carries the story.

    The payload must contain "summary" and "narrative", and the narrative
    must include at least the "phase_1" and "phase_4" entries.
    """
    response = client.get(f"/training/summary/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert payload["task_id"] == TASK
    for section in ("summary", "narrative"):
        assert section in payload
    narrative = payload["narrative"]
    for phase_key in ("phase_1", "phase_4"):
        assert phase_key in narrative


def test_curve_full():
    """GET /training/curve/{TASK} returns a non-empty curve of episode points.

    Only the first point is inspected; it must expose the "episode",
    "reward", "score" and "phase" fields.
    """
    response = client.get(f"/training/curve/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert "curve" in payload
    curve = payload["curve"]
    assert len(curve) > 0
    first_point = curve[0]
    for field in ("episode", "reward", "score", "phase"):
        assert field in first_point


def test_curve_downsample():
    """GET /training/curve/{TASK}?downsample=5 answers 200 with a bounded count."""
    url = f"/training/curve/{TASK}?downsample=5"
    response = client.get(url)
    assert response.status_code == 200
    point_count = response.json()["total_points"]
    # NOTE(review): this upper bound is very loose and does not show that
    # downsampling reduced anything -- consider comparing against the full
    # curve once the semantics of `downsample` are confirmed.
    assert point_count <= 100000


def test_actions():
    """GET /training/actions/{TASK} returns exactly 5 checkpoints and an insight."""
    response = client.get(f"/training/actions/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    assert "checkpoints" in payload
    checkpoints = payload["checkpoints"]
    assert len(checkpoints) == 5
    assert "insight" in payload


def test_episode_first():
    """GET /training/episode/{TASK}/1 returns episode 1 with all detail fields."""
    response = client.get(f"/training/episode/{TASK}/1")
    assert response.status_code == 200
    payload = response.json()
    assert payload["episode"] == 1
    expected_fields = (
        "reward",
        "score",
        "fn1_valid",
        "fn2_no_halluc",
        "fn3_env_score",
        "message",
        "running_best_reward",
    )
    for field in expected_fields:
        assert field in payload


def test_episode_last():
    """The final episode, as reported by the summary endpoint, is retrievable.

    The episode count is read from GET /training/summary/{TASK} and then used
    as the episode number for GET /training/episode/{TASK}/{n}.
    """
    summary_response = client.get(f"/training/summary/{TASK}")
    # Check the prerequisite call explicitly so a broken summary endpoint
    # shows up as a status mismatch instead of an opaque KeyError below.
    assert summary_response.status_code == 200
    total = summary_response.json()["total_episodes"]

    r = client.get(f"/training/episode/{TASK}/{total}")
    assert r.status_code == 200
    # Mirror test_episode_first: the payload should describe the episode
    # that was actually requested.
    assert r.json()["episode"] == total


def test_episode_out_of_range():
    """Requesting episode 99999 is rejected with HTTP 400."""
    out_of_range_url = f"/training/episode/{TASK}/99999"
    response = client.get(out_of_range_url)
    assert response.status_code == 400


def test_comparison():
    """GET /training/comparison/{TASK} reports before/after scores and a verdict.

    Both the "before" and "after" snapshots must carry a positive "score",
    and the "improvement" section must contain a "verdict".
    """
    response = client.get(f"/training/comparison/{TASK}")
    assert response.status_code == 200
    payload = response.json()
    for section in ("before", "after", "improvement"):
        assert section in payload
    assert "verdict" in payload["improvement"]
    for snapshot in ("before", "after"):
        assert payload[snapshot]["score"] > 0


def test_missing_task_404():
    """A summary request for an unknown task id yields HTTP 404."""
    response = client.get("/training/summary/nonexistent_task_xyz")
    status = response.status_code
    assert status == 404


def test_stream_headers():
    """The SSE stream endpoint answers 200 with a text/event-stream content type."""
    stream_url = f"/training/stream/{TASK}?delay_ms=0"
    # Only status and headers are inspected; the streamed body is not read.
    with client.stream("GET", stream_url) as response:
        assert response.status_code == 200
        content_type = response.headers["content-type"]
        assert "text/event-stream" in content_type