# Repository: EntropyEnv — tests/test_endpoints.py
# Last commit: "first commit" (4ec75cf) by immortalindeed
# tests/test_endpoints.py
# Basic endpoint tests for the environment.
# Run: python -m pytest tests/ -v
import requests
import pytest
# Root URL of the environment server that every test in this module sends
# its requests to; the server must already be listening here.
BASE_URL = 'http://localhost:7860'
def test_health_check():
    """GET / should return 200 with status 'ok' and a task count of 9."""
    # Explicit timeout: without one, requests waits forever on a stalled
    # server and the whole test run hangs instead of failing.
    r = requests.get(f'{BASE_URL}/', timeout=30)
    assert r.status_code == 200
    data = r.json()
    assert data['status'] == 'ok'
    assert data['tasks'] == 9
def test_reset_valid_task():
    """POST /reset with valid task_id should return episode_id and observation.

    Also checks that the observation for 'sec_easy' reports task_type
    'security'.
    """
    # Explicit timeout so an unresponsive server fails the test instead of
    # blocking it indefinitely.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=30,
    )
    assert r.status_code == 200
    data = r.json()
    assert 'episode_id' in data
    assert 'observation' in data
    assert data['observation']['task_type'] == 'security'
def test_reset_all_tasks():
    """POST /reset should work for all 9 task IDs."""
    tasks = [
        'sec_easy', 'sec_medium', 'sec_hard',
        'dep_easy', 'dep_medium', 'dep_hard',
        'cli_easy', 'cli_medium', 'cli_hard',
    ]
    for task_id in tasks:
        # Explicit timeout so one stalled reset cannot hang the whole loop.
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=30,
        )
        # Name the task in every failure message; otherwise a bad status
        # code gives no hint about which of the nine IDs broke.
        assert r.status_code == 200, (
            f'Unexpected status {r.status_code} for {task_id}'
        )
        data = r.json()
        assert 'episode_id' in data, f'No episode_id for {task_id}'
        assert 'observation' in data, f'No observation for {task_id}'
def test_reset_invalid_task():
    """POST /reset with invalid task_id should still return 200."""
    # Explicit timeout so an unresponsive server fails the test instead of
    # blocking it indefinitely.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'nonexistent'},
        timeout=30,
    )
    assert r.status_code == 200
def test_step_valid_action():
    """POST /step with valid action should return reward and observation.

    The response must also carry a 'done' field, and the reward must lie
    in [0.0, 1.0].
    """
    # Reset first to obtain an episode to act on.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=30,
    )
    # Fail here with a clear message rather than with a KeyError on the
    # episode_id lookup below if the reset itself went wrong.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    # Step
    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
    assert r.status_code == 200
    data = r.json()
    assert 'reward' in data
    assert 'done' in data
    assert 'observation' in data
    assert 0.0 <= data['reward'] <= 1.0
def test_step_invalid_episode():
    """POST /step with invalid episode_id should return 200 with done=True."""
    payload = {
        'episode_id': 'nonexistent',
        'action_type': 'identify_vulnerability',
    }
    # Explicit timeout so an unresponsive server fails the test instead of
    # blocking it indefinitely.
    r = requests.post(f'{BASE_URL}/step', json=payload, timeout=30)
    assert r.status_code == 200
    data = r.json()
    assert data['done'] is True
def test_state_endpoint():
    """GET /state should return episode info.

    For a freshly reset episode the response must echo the episode_id and
    report done as False.
    """
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=30,
    )
    # Check the setup call explicitly so a failed reset is reported as such
    # rather than as a KeyError on the lookup below.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    r = requests.get(
        f'{BASE_URL}/state',
        params={'episode_id': ep_id},
        timeout=30,
    )
    assert r.status_code == 200
    data = r.json()
    assert data['episode_id'] == ep_id
    assert data['done'] is False
def test_reward_range():
    """Rewards should always be in [0.0, 1.0].

    Exercised by sending an unrecognised action_type to one task from each
    of the 'sec', 'dep' and 'cli' groups.
    """
    tasks = ['sec_easy', 'dep_easy', 'cli_easy']
    for task_id in tasks:
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=30,
        )
        # Surface a failed reset with the task name instead of a bare
        # KeyError on the episode_id lookup.
        assert r.status_code == 200, f'Reset failed for {task_id}'
        ep_id = r.json()['episode_id']

        # Send an invalid action. The step's status code is deliberately
        # not asserted: this test only constrains the reward value.
        r = requests.post(
            f'{BASE_URL}/step',
            json={
                'episode_id': ep_id,
                'action_type': 'invalid_action_type',
            },
            timeout=30,
        )
        data = r.json()
        assert 0.0 <= data['reward'] <= 1.0, f'Reward out of range for {task_id}'
def test_step_enriched_observation():
    """Step observations should include task context fields.

    The fields checked are 'task_type', 'max_steps' and 'steps_remaining'.
    """
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=30,
    )
    # Check the setup call so a failed reset is not mistaken for a missing
    # observation field further down.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
    # Same request as test_step_valid_action, which expects 200; checking it
    # here gives a clear failure before indexing into the JSON body.
    assert r.status_code == 200, f'Step failed with status {r.status_code}'
    obs = r.json()['observation']
    assert 'task_type' in obs
    assert 'max_steps' in obs
    assert 'steps_remaining' in obs